From eb15e82a2bfb2c162023b718c62f36661f5d7c66 Mon Sep 17 00:00:00 2001 From: rofinn Date: Mon, 8 Apr 2024 17:13:36 -0700 Subject: [PATCH] Add a rounding mode and dispatch when calculating increment. 1. We add a `r::Union{RoundingMode, Nothing}=nothing` to `Interpolate`, avoiding any breaking changes. 2. Added unit tests for unsigned integers and cases with and without the RoundingMode. 3. Increment calculation: - Still using `(next - prev) / n` in the default case - If a rounding mode is specified then `div` is used with the corresponding mode (avoids `InexactError`) - If we're given unsigned ints then we convert them to integers for the increment calculation (avoids integer overflow) --- src/imputors/interp.jl | 22 +++++++++++++--------- test/imputors/interp.jl | 27 ++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index 9329021..cccdc78 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -1,5 +1,5 @@ """ - Interpolate(; limit=nothing) + Interpolate(; limit=nothing, r=nothing) Performs linear interpolation between the nearest values in an vector. The current implementation is univariate, so each variable in a table or matrix will @@ -11,6 +11,8 @@ that all missing values will be imputed. # Keyword Arguments * `limit::Union{UInt, Nothing}`: Optionally limit the gap sizes that can be interpolated. +* `r::Union{RoundingMode, Nothing}`: Optionally specify a rounding mode. + Avoids `InexactError`s when interpolating over integers. # Example ```jldoctest @@ -34,9 +36,10 @@ julia> impute(M, Interpolate(; limit=2); dims=:rows) """ struct Interpolate <: Imputor limit::Union{UInt, Nothing} + r::Union{RoundingMode, Nothing} end -Interpolate(; limit=nothing) = Interpolate(limit) +Interpolate(; limit=nothing, r=nothing) = Interpolate(limit, r) function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where T @assert !all(ismissing, data) @@ -51,8 +54,7 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w gap_sz = (next_idx - prev_idx) - 1 if imp.limit === nothing || gap_sz <= imp.limit - diff = data[next_idx] - data[prev_idx] - incr = diff / T(gap_sz + 1) + incr = _calculate_increment(data[prev_idx], data[next_idx], gap_sz + 1, imp.r) val = data[prev_idx] + incr # Iteratively fill in the values @@ -73,8 +75,10 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end -function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where {T<:Union{Signed, Unsigned}} - dataf = _impute!(float(data), imp) - data .= round.(Union{T, Missing}, dataf) - return data -end +# Default cases where no rounding behaviour is specified +_calculate_increment(a, b, n, ::Nothing) = (b - a) / n +_calculate_increment(a::Unsigned, b::Unsigned, n, r::Nothing) = _calculate_increment(Int(a), Int(b), n, r) + +# Pass a rounding mode to `div` +_calculate_increment(a, b, n, r) = div(b - a, n, r) +_calculate_increment(a::Unsigned, b::Unsigned, n, r) = _calculate_increment(Int(a), Int(b), n, r) diff --git a/test/imputors/interp.jl b/test/imputors/interp.jl index 3d30ee9..41b904f 100644 --- a/test/imputors/interp.jl +++ b/test/imputors/interp.jl @@ -90,10 +90,35 @@ @test ismissing(result[1]) @test ismissing(result[20]) - # Test inexact error + # Test with UInt + c = [0x1, missing, 0x3, 0x4] + @test Impute.interp(c) == [0x1, 0x2, 0x3, 0x4] + + # Test reverse case where the increment is negative + @test Impute.interp(reverse(c)) == [0x4, 0x3, 0x2, 0x1] + + # Test inexact error (no rounding mode provided) # https://github.com/invenia/Impute.jl/issues/71 c = [1, missing, 2, 3] @test_throws InexactError Impute.interp(c) + + # Test with UInt + c = [0x1, missing, 0x2, 0x3] + @test_throws InexactError Impute.interp(c) + + # Test reverse case where the increment is negative + @test_throws InexactError Impute.interp(reverse(c)) + + # Test inexact cases with a rounding mode + c = [1, missing, 2, 3] + Impute.interp(c; r=RoundToZero) == [1, 1, 2, 3] + + # Test with UInt + c = [0x1, missing, 0x2, 0x3] + Impute.interp(c; r=RoundNearest) == [0x1, 0x2, 0x2, 0x3] + + # Test reverse case where the increment is negative + Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x1, 0x1] end # TODO Test error cases on non-numeric types