Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle interp for integers by casting and rounding (#71) #142

Merged
merged 9 commits into from
Apr 19, 2024
58 changes: 39 additions & 19 deletions src/imputors/interp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Interpolate(; limit=nothing)
Interpolate(; limit=nothing, r=nothing)

Performs linear interpolation between the nearest values in a vector.
The current implementation is univariate, so each variable in a table or matrix will
Expand All @@ -11,6 +11,8 @@ that all missing values will be imputed.

# Keyword Arguments
* `limit::Union{UInt, Nothing}`: Optionally limit the gap sizes that can be interpolated.
* `r::Union{RoundingMode, Nothing}`: Optionally specify a rounding mode.
Avoids `InexactError`s when interpolating over integers.

# Example
```jldoctest
Expand All @@ -34,35 +36,25 @@ julia> impute(M, Interpolate(; limit=2); dims=:rows)
"""
struct Interpolate <: Imputor
# Largest gap size that may be interpolated; `nothing` means gaps of any size are filled.
limit::Union{UInt, Nothing}
# Rounding mode applied when interpolating over integer data; `nothing` means the
# interpolated values are not rounded.
r::Union{RoundingMode, Nothing}
end

Interpolate(; limit=nothing) = Interpolate(limit)
# Keyword constructor: both settings are optional and default to `nothing`.
function Interpolate(; limit=nothing, r=nothing)
    return Interpolate(limit, r)
end

function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where T
@assert !all(ismissing, data)
i = findfirst(!ismissing, data) + 1

while i < lastindex(data)
if ismissing(data[i])
prev_idx = i - 1
next_idx = findnext(!ismissing, data, i + 1)

if next_idx !== nothing
gap_sz = (next_idx - prev_idx) - 1

if imp.limit === nothing || gap_sz <= imp.limit
diff = data[next_idx] - data[prev_idx]
incr = diff / T(gap_sz + 1)
val = data[prev_idx] + incr

# Iteratively fill in the values
for j in i:(next_idx - 1)
data[j] = val
val += incr
end
j = _findnext(data, i + 1)

if j !== nothing
if imp.limit === nothing || j - i + 1 <= imp.limit
_interpolate!(data, i:j, data[i - 1], data[j + 1], imp.r)
end

i = next_idx
i = j + 1
else
break
end
Expand All @@ -72,3 +64,31 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w

return data
end

# Kernel function (function barrier) that fills `data[indices]` with values linearly
# interpolated between `prev` and `next`, optionally rounded according to `r`.
# Keeping the loop in its own function avoids type instability issues, see
# https://docs.julialang.org/en/v1/manual/performance-tips/#kernel-functions
function _interpolate!(data, indices, prev, next, r)
    # The gap of `length(indices)` values is bridged in `length(indices) + 1` steps
    step = _calculate_increment(prev, next, length(indices) + 1)

    # `offset` counts how many steps away from `prev` the current position is
    offset = 0
    for idx in indices
        offset += 1
        data[idx] = _calculate_value(prev, step, offset, r)
    end

    return nothing
end

# Returns the index just before the first non-missing element found at or after `i`,
# i.e. the last index of the block of missing values. Returns `nothing` when no
# non-missing element exists from `i` onwards.
function _findnext(data, i)
    nxt = findnext(!ismissing, data, i)
    return nxt === nothing ? nothing : nxt - 1
end

# Calculates the per-step increment when moving from `a` to `b` in `n` equal steps.
_calculate_increment(a, b, n) = (b - a) / n

# Special case for unsigned integers, where `b - a` would wrap around when `b < a`.
# The difference is always taken in the direction that cannot wrap and the sign is
# applied afterwards. Unlike converting the endpoints with `Int(...)`, this does not
# throw an `InexactError` for values above `typemax(Int)` (e.g. large `UInt64` or
# `UInt128` values, or `UInt32` values on 32-bit platforms).
function _calculate_increment(a::T, b::T, n) where {T<:Unsigned}
    # `float` is applied before dividing so `n` is never promoted to an unsigned type
    return b >= a ? float(b - a) / n : -(float(a - b) / n)
end

# Calculates the interpolated value that lies `i` increments away from `prev`.
# Generic method: the raw value `prev + incr * i` is returned as is (`r` is unused).
function _calculate_value(prev, incr, i, r)
    return prev + incr * i
end

# Integer method used when a rounding mode is supplied: the raw value is rounded back
# to the integer type of `prev` according to `r`.
function _calculate_value(prev::T, incr, i, r::RoundingMode) where {T<:Integer}
    raw = prev + incr * i
    return round(T, raw, r)
end
51 changes: 50 additions & 1 deletion test/imputors/interp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,59 @@
@test ismissing(result[1])
@test ismissing(result[20])

# Test inexact error
# Test with UInt
c = [0x1, missing, 0x3, 0x4]
@test Impute.interp(c) == [0x1, 0x2, 0x3, 0x4]

# Test reverse case where the increment is negative
@test Impute.interp(reverse(c)) == [0x4, 0x3, 0x2, 0x1]

# Test inexact error (no rounding mode provided)
# https://github.com/invenia/Impute.jl/issues/71
c = [1, missing, 2, 3]
@test_throws InexactError Impute.interp(c)

# Test with UInt
c = [0x1, missing, 0x2, 0x3]
@test_throws InexactError Impute.interp(c)

# Test reverse case where the increment is negative
@test_throws InexactError Impute.interp(reverse(c))

# Test inexact cases with a rounding mode
c = [1, missing, 2, 3]
@test Impute.interp(c; r=RoundToZero) == [1, 1, 2, 3]

# Test with UInt
c = [0x1, missing, 0x2, 0x3]
@test Impute.interp(c; r=RoundNearest) == [0x1, 0x2, 0x2, 0x3]

# Test reverse case where the increment is negative
@test Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x2, 0x1]

# Test rounding doesn't cause values to exceed endpoint values
@test Impute.interp([1, missing, missing, 2]; r=RoundUp) == [1, 2, 2, 2]
@test Impute.interp([2, missing, missing, 1]; r=RoundUp) == [2, 2, 2, 1]
@test Impute.interp([1, missing, missing, 0]; r=RoundDown) == [1, 0, 0, 0]
@test Impute.interp([0x1, missing, missing, 0x0]; r=RoundDown) == [0x1, 0x0, 0x0, 0x0]

# Test long gaps (above .5 increment)
@test Impute.interp([2, fill(missing, 10)..., 8]; r=RoundNearest) == [2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8]
@test Impute.interp([0x2, fill(missing, 10)..., 0x8]; r=RoundNearest) == [0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8]
@test Impute.interp([8, fill(missing, 10)..., 2]; r=RoundNearest) == [8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2]
@test Impute.interp([0x8, fill(missing, 10)..., 0x2]; r=RoundNearest) == [0x8, 0x7, 0x7, 0x6, 0x6, 0x5, 0x5, 0x4, 0x4, 0x3, 0x3, 0x2]

# Test long gaps (at .5 increment)
@test Impute.interp([2, fill(missing, 11)..., 8]; r=RoundNearest) == [2, 2, 3, 4, 4, 4, 5, 6, 6, 6, 7, 8, 8]
@test Impute.interp([0x2, fill(missing, 11)..., 0x8]; r=RoundNearest) == [0x2, 0x2, 0x3, 0x4, 0x4, 0x4, 0x5, 0x6, 0x6, 0x6, 0x7, 0x8, 0x8]
@test Impute.interp([8, fill(missing, 11)..., 2]; r=RoundNearest) == [8, 8, 7, 6, 6, 6, 5, 4, 4, 4, 3, 2, 2]
@test Impute.interp([0x8, fill(missing, 11)..., 0x2]; r=RoundNearest) == [0x8, 0x8, 0x7, 0x6, 0x6, 0x6, 0x5, 0x4, 0x4, 0x4, 0x3, 0x2, 0x2]

# Test long gaps (below .5 increment)
@test Impute.interp([2, fill(missing, 12)..., 8]; r=RoundNearest) == [2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8]
@test Impute.interp([0x2, fill(missing, 12)..., 0x8]; r=RoundNearest) == [0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8]
@test Impute.interp([8, fill(missing, 12)..., 2]; r=RoundNearest) == [8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2]
@test Impute.interp([0x8, fill(missing, 12)..., 0x2]; r=RoundNearest) == [0x8, 0x8, 0x7, 0x7, 0x6, 0x6, 0x5, 0x5, 0x4, 0x4, 0x3, 0x3, 0x2, 0x2]
end

# TODO Test error cases on non-numeric types
Expand Down
Loading