Skip to content

Commit

Permalink
Optimise decimal casting for infallible conversions (#7021)
Browse files Browse the repository at this point in the history
* Implement fast path for infallible case.

* Add benchmark for infallible casting to smaller scale.

* Format benchmark.

* Improve test cases.

* Simplify implementation to reduce code duplication.

* Add error test case.

* Add extensive test suite.

* Remove specialization.

* Add example as comment.

* Reorder fields.

* Check preconditions on precision and scale.

* Test decimal cast between multiple types.

* Check against error message template.

* Move validation into infallible branch.

* Use generic function instead of macro to run tests.

* Add clarifying comment about unwrapping.

* Fix typo.

* Add test case for input/output scale = 0.

* Fix clippy.
  • Loading branch information
aweltsch authored Feb 24, 2025
1 parent 493d3ee commit d54a8c6
Show file tree
Hide file tree
Showing 3 changed files with 408 additions and 7 deletions.
51 changes: 46 additions & 5 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ where

pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
array: &PrimitiveArray<I>,
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
Expand All @@ -100,9 +101,22 @@ where
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
let delta_scale = input_scale - output_scale;
// if the reduction of the input number through scaling (dividing) is greater
// than a possible precision loss (plus potential increase via rounding)
// every input number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then a decrease of the scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
// The rounding may add an additional digit, so for the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);

let div = I::Native::from_decimal(10_i128)
.unwrap()
.pow_checked((input_scale - output_scale) as u32)?;
.pow_checked(delta_scale as u32)?;

let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
let half_neg = half.neg_wrapping();
Expand All @@ -121,7 +135,13 @@ where
O::Native::from_decimal(adjusted)
};

Ok(if cast_options.safe {
Ok(if is_infallible_cast {
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
// to fit into the target type
array.unary(g)
} else if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
Expand All @@ -133,6 +153,7 @@ where

pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
array: &PrimitiveArray<I>,
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
Expand All @@ -145,13 +166,27 @@ where
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
let delta_scale = output_scale - input_scale;
let mul = O::Native::from_decimal(10_i128)
.unwrap()
.pow_checked((output_scale - input_scale) as u32)?;

.pow_checked(delta_scale as u32)?;

// if the gain in precision (digits) is at least as large as the number of digits
// added by the multiplication due to scaling, every number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then an increase of scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
// needs to provide at least 8 digits precision
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());

Ok(if cast_options.safe {
Ok(if is_infallible_cast {
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
// unwrapping is safe since the result is guaranteed to fit into the target type
let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
array.unary(f)
} else if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
Expand Down Expand Up @@ -180,14 +215,17 @@ where
} else if input_scale <= output_scale {
convert_to_bigger_or_equal_scale_decimal::<T, T>(
array,
input_precision,
input_scale,
output_precision,
output_scale,
cast_options,
)?
} else {
// input_scale > output_scale
convert_to_smaller_scale_decimal::<T, T>(
array,
input_precision,
input_scale,
output_precision,
output_scale,
Expand All @@ -204,6 +242,7 @@ where
// Support two different types of decimal cast operations
pub(crate) fn cast_decimal_to_decimal<I, O>(
array: &PrimitiveArray<I>,
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
Expand All @@ -218,6 +257,7 @@ where
let array: PrimitiveArray<O> = if input_scale > output_scale {
convert_to_smaller_scale_decimal::<I, O>(
array,
input_precision,
input_scale,
output_precision,
output_scale,
Expand All @@ -226,6 +266,7 @@ where
} else {
convert_to_bigger_or_equal_scale_decimal::<I, O>(
array,
input_precision,
input_scale,
output_precision,
output_scale,
Expand Down
Loading

0 comments on commit d54a8c6

Please sign in to comment.