Skip to content

Commit

Permalink
Add a rounding mode and dispatch when calculating increment.
Browse files Browse the repository at this point in the history
1. We add a `r::Union{RoundingMode, Nothing}=nothing` to `Interpolate`, avoiding any breaking changes.
2. Added unit tests for unsigned integers and cases with and without the RoundingMode.
3. Increment calculation:
  - Still using `(next - prev) / n` in the default case
  - If a rounding mode is specified then `div` is used with the corresponding mode (avoids `InexactError`)
  - If we're given unsigned ints then we convert them to integers for the increment calculation (avoids integer overflow)
  • Loading branch information
rofinn committed Apr 9, 2024
1 parent ee9856d commit eb15e82
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
22 changes: 13 additions & 9 deletions src/imputors/interp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Interpolate(; limit=nothing)
Interpolate(; limit=nothing, r=nothing)
Performs linear interpolation between the nearest values in an vector.
The current implementation is univariate, so each variable in a table or matrix will
Expand All @@ -11,6 +11,8 @@ that all missing values will be imputed.
# Keyword Arguments
* `limit::Union{UInt, Nothing}`: Optionally limit the gap sizes that can be interpolated.
* `r::Union{RoundingMode, Nothing}`: Optionally specify a rounding mode.
Avoids `InexactError`s when interpolating over integers.
# Example
```jldoctest
Expand All @@ -34,9 +36,10 @@ julia> impute(M, Interpolate(; limit=2); dims=:rows)
"""
struct Interpolate <: Imputor
limit::Union{UInt, Nothing}
r::Union{RoundingMode, Nothing}
end

Interpolate(; limit=nothing) = Interpolate(limit)
Interpolate(; limit=nothing, r=nothing) = Interpolate(limit, r)

function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where T
@assert !all(ismissing, data)
Expand All @@ -51,8 +54,7 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w
gap_sz = (next_idx - prev_idx) - 1

if imp.limit === nothing || gap_sz <= imp.limit
diff = data[next_idx] - data[prev_idx]
incr = diff / T(gap_sz + 1)
incr = _calculate_increment(data[prev_idx], data[next_idx], gap_sz + 1, imp.r)
val = data[prev_idx] + incr

# Iteratively fill in the values
Expand All @@ -73,8 +75,10 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w
return data
end

function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where {T<:Union{Signed, Unsigned}}
dataf = _impute!(float(data), imp)
data .= round.(Union{T, Missing}, dataf)
return data
end
# Default cases where no rounding behaviour is specified
_calculate_increment(a, b, n, ::Nothing) = (b - a) / n
_calculate_increment(a::Unsigned, b::Unsigned, n, r::Nothing) = _calculate_increment(Int(a), Int(b), n, r)

# Pass a rounding mode to `div`
_calculate_increment(a, b, n, r) = div(b - a, n, r)
_calculate_increment(a::Unsigned, b::Unsigned, n, r) = _calculate_increment(Int(a), Int(b), n, r)
27 changes: 26 additions & 1 deletion test/imputors/interp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,35 @@
@test ismissing(result[1])
@test ismissing(result[20])

# Test inexact error
# Test with UInt
c = [0x1, missing, 0x3, 0x4]
@test Impute.interp(c) == [0x1, 0x2, 0x3, 0x4]

# Test reverse case where the increment is negative
@test Impute.interp(reverse(c)) == [0x4, 0x3, 0x2, 0x1]

# Test inexact error (no rounding mode provided)
# https://github.com/invenia/Impute.jl/issues/71
c = [1, missing, 2, 3]
@test_throws InexactError Impute.interp(c)

# Test with UInt
c = [0x1, missing, 0x2, 0x3]
@test_throws InexactError Impute.interp(c)

# Test reverse case where the increment is negative
@test_throws InexactError Impute.interp(reverse(c))

# Test inexact cases with a rounding mode
c = [1, missing, 2, 3]
Impute.interp(c; r=RoundToZero) == [1, 1, 2, 3]

# Test with UInt
c = [0x1, missing, 0x2, 0x3]
Impute.interp(c; r=RoundNearest) == [0x1, 0x2, 0x2, 0x3]

# Test reverse case where the increment is negative
Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x1, 0x1]
end

# TODO Test error cases on non-numeric types
Expand Down

0 comments on commit eb15e82

Please sign in to comment.