From ee9856d60511c8b6b828f733865583b5f18b9691 Mon Sep 17 00:00:00 2001 From: kpa28 Date: Mon, 4 Mar 2024 21:12:57 -0800 Subject: [PATCH 1/9] handle interp for integers by casting and rounding (#71) --- src/imputors/interp.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index 2963c36..9329021 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -72,3 +72,9 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end + +function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where {T<:Union{Signed, Unsigned}} + dataf = _impute!(float(data), imp) + data .= round.(Union{T, Missing}, dataf) + return data +end From eb15e82a2bfb2c162023b718c62f36661f5d7c66 Mon Sep 17 00:00:00 2001 From: rofinn Date: Mon, 8 Apr 2024 17:13:36 -0700 Subject: [PATCH 2/9] Add a rounding mode and dispatch when calculating increment. 1. We add a `r::Union{RoundingMode, Nothing}=nothing` to `Interpolate`, avoiding any breaking changes. 2. Added unit tests for unsigned integers and cases with and without the RoundingMode. 3. Increment calculation: - Still using `(next - prev) / n` in the default case - If a rounding mode is specified then `div` is used with the corresponding mode (avoids `InexactError`) - If we're given unsigned ints then we convert them to integers for the increment calculation (avoids integer overflow) --- src/imputors/interp.jl | 22 +++++++++++++--------- test/imputors/interp.jl | 27 ++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index 9329021..cccdc78 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -1,5 +1,5 @@ """ - Interpolate(; limit=nothing) + Interpolate(; limit=nothing, r=nothing) Performs linear interpolation between the nearest values in an vector. 
The current implementation is univariate, so each variable in a table or matrix will @@ -11,6 +11,8 @@ that all missing values will be imputed. # Keyword Arguments * `limit::Union{UInt, Nothing}`: Optionally limit the gap sizes that can be interpolated. +* `r::Union{RoundingMode, Nothing}`: Optionally specify a rounding mode. + Avoids `InexactError`s when interpolating over integers. # Example ```jldoctest @@ -34,9 +36,10 @@ julia> impute(M, Interpolate(; limit=2); dims=:rows) """ struct Interpolate <: Imputor limit::Union{UInt, Nothing} + r::Union{RoundingMode, Nothing} end -Interpolate(; limit=nothing) = Interpolate(limit) +Interpolate(; limit=nothing, r=nothing) = Interpolate(limit, r) function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where T @assert !all(ismissing, data) @@ -51,8 +54,7 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w gap_sz = (next_idx - prev_idx) - 1 if imp.limit === nothing || gap_sz <= imp.limit - diff = data[next_idx] - data[prev_idx] - incr = diff / T(gap_sz + 1) + incr = _calculate_increment(data[prev_idx], data[next_idx], gap_sz + 1, imp.r) val = data[prev_idx] + incr # Iteratively fill in the values @@ -73,8 +75,10 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end -function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) where {T<:Union{Signed, Unsigned}} - dataf = _impute!(float(data), imp) - data .= round.(Union{T, Missing}, dataf) - return data -end +# Default cases where no rounding behaviour is specified +_calculate_increment(a, b, n, ::Nothing) = (b - a) / n +_calculate_increment(a::Unsigned, b::Unsigned, n, r::Nothing) = _calculate_increment(Int(a), Int(b), n, r) + +# Pass a rounding mode to `div` +_calculate_increment(a, b, n, r) = div(b - a, n, r) +_calculate_increment(a::Unsigned, b::Unsigned, n, r) = _calculate_increment(Int(a), Int(b), n, r) diff --git a/test/imputors/interp.jl 
b/test/imputors/interp.jl index 3d30ee9..41b904f 100644 --- a/test/imputors/interp.jl +++ b/test/imputors/interp.jl @@ -90,10 +90,35 @@ @test ismissing(result[1]) @test ismissing(result[20]) - # Test inexact error + # Test with UInt + c = [0x1, missing, 0x3, 0x4] + @test Impute.interp(c) == [0x1, 0x2, 0x3, 0x4] + + # Test reverse case where the increment is negative + @test Impute.interp(reverse(c)) == [0x4, 0x3, 0x2, 0x1] + + # Test inexact error (no rounding mode provided) # https://github.com/invenia/Impute.jl/issues/71 c = [1, missing, 2, 3] @test_throws InexactError Impute.interp(c) + + # Test with UInt + c = [0x1, missing, 0x2, 0x3] + @test_throws InexactError Impute.interp(c) + + # Test reverse case where the increment is negative + @test_throws InexactError Impute.interp(reverse(c)) + + # Test inexact cases with a rounding mode + c = [1, missing, 2, 3] + Impute.interp(c; r=RoundToZero) == [1, 1, 2, 3] + + # Test with UInt + c = [0x1, missing, 0x2, 0x3] + Impute.interp(c; r=RoundNearest) == [0x1, 0x2, 0x2, 0x3] + + # Test reverse case where the increment is negative + Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x1, 0x1] end # TODO Test error cases on non-numeric types From 7493e619f322d3cbf5406edefcb18bbef6c14eca Mon Sep 17 00:00:00 2001 From: rofinn Date: Fri, 12 Apr 2024 18:28:40 -0700 Subject: [PATCH 3/9] Add interp test for exceeding endpoint values. 
--- test/imputors/interp.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/imputors/interp.jl b/test/imputors/interp.jl index 41b904f..edd3d69 100644 --- a/test/imputors/interp.jl +++ b/test/imputors/interp.jl @@ -111,14 +111,20 @@ # Test inexact cases with a rounding mode c = [1, missing, 2, 3] - Impute.interp(c; r=RoundToZero) == [1, 1, 2, 3] + @test Impute.interp(c; r=RoundToZero) == [1, 1, 2, 3] # Test with UInt c = [0x1, missing, 0x2, 0x3] - Impute.interp(c; r=RoundNearest) == [0x1, 0x2, 0x2, 0x3] + @test Impute.interp(c; r=RoundNearest) == [0x1, 0x1, 0x2, 0x3] # Test reverse case where the increment is negative - Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x1, 0x1] + @test Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x2, 0x1] + + # Test rounding doesn't cause values to exceed endpoint values + @test Impute.interp([1, missing, missing, 2]; r=RoundUp) == [1, 2, 3, 2] + @test Impute.interp([2, missing, missing, 1]; r=RoundUp) == [2, 2, 2, 1] + @test Impute.interp([1, missing, missing, 0]; r=RoundDown) == [1, 0, -1, 0] + @test_throws InexactError Impute.interp([0x1, missing, missing, 0x0]; r=RoundDown) end # TODO Test error cases on non-numeric types From 737bda6c331880497c72d4fff95617e74da845a4 Mon Sep 17 00:00:00 2001 From: rofinn Date: Fri, 12 Apr 2024 18:37:44 -0700 Subject: [PATCH 4/9] Fix tests by clamping the inserted value. 
--- src/imputors/interp.jl | 12 +++++++++--- test/imputors/interp.jl | 6 +++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index cccdc78..b38e6d7 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -54,12 +54,18 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w gap_sz = (next_idx - prev_idx) - 1 if imp.limit === nothing || gap_sz <= imp.limit - incr = _calculate_increment(data[prev_idx], data[next_idx], gap_sz + 1, imp.r) - val = data[prev_idx] + incr + prev = data[prev_idx] + next = data[next_idx] + incr = _calculate_increment(prev, next, gap_sz + 1, imp.r) + val = prev + incr # Iteratively fill in the values + # Determine hi and lo values for clamping in the loop + # According to @benchmark calling extrema with a tuple has the same + # performance as calling min/max individually. + lo, hi = extrema((prev, next)) for j in i:(next_idx - 1) - data[j] = val + data[j] = clamp(val, lo, hi) val += incr end end diff --git a/test/imputors/interp.jl b/test/imputors/interp.jl index edd3d69..60a8b32 100644 --- a/test/imputors/interp.jl +++ b/test/imputors/interp.jl @@ -121,10 +121,10 @@ @test Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x2, 0x1] # Test rounding doesn't cause values to exceed endpoint values - @test Impute.interp([1, missing, missing, 2]; r=RoundUp) == [1, 2, 3, 2] + @test Impute.interp([1, missing, missing, 2]; r=RoundUp) == [1, 2, 2, 2] @test Impute.interp([2, missing, missing, 1]; r=RoundUp) == [2, 2, 2, 1] - @test Impute.interp([1, missing, missing, 0]; r=RoundDown) == [1, 0, -1, 0] - @test_throws InexactError Impute.interp([0x1, missing, missing, 0x0]; r=RoundDown) + @test Impute.interp([1, missing, missing, 0]; r=RoundDown) == [1, 0, 0, 0] + @test Impute.interp([0x1, missing, missing, 0x0]; r=RoundDown) == [0x1, 0x0, 0x0, 0x0] end # TODO Test error cases on non-numeric types From 2f7fc0abf12e2ecb956e0e30f1b61e835f1c1761 Mon Sep 17 
00:00:00 2001 From: kpa28 Date: Tue, 16 Apr 2024 20:34:37 -0700 Subject: [PATCH 5/9] interpolate integers in floating point using generators to avoid allocations + some tests --- src/imputors/interp.jl | 59 +++++++++++++++++++++++++++-------------- test/imputors/interp.jl | 20 +++++++++++++- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index b38e6d7..01a29be 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -54,20 +54,8 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w gap_sz = (next_idx - prev_idx) - 1 if imp.limit === nothing || gap_sz <= imp.limit - prev = data[prev_idx] - next = data[next_idx] - incr = _calculate_increment(prev, next, gap_sz + 1, imp.r) - val = prev + incr - - # Iteratively fill in the values - # Determine hi and lo values for clamping in the loop - # According to @benchmark calling extrema with a tuple has the same - # performance as calling min/max individually. - lo, hi = extrema((prev, next)) - for j in i:(next_idx - 1) - data[j] = clamp(val, lo, hi) - val += incr - end + gen = _gen_interp(data[prev_idx], data[next_idx], gap_sz+1, imp.r) + _gen_set!(data, prev_idx, gen) end i = next_idx @@ -81,10 +69,41 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end -# Default cases where no rounding behaviour is specified -_calculate_increment(a, b, n, ::Nothing) = (b - a) / n -_calculate_increment(a::Unsigned, b::Unsigned, n, r::Nothing) = _calculate_increment(Int(a), Int(b), n, r) +""" +Set a vector slice over the values of a generator, starting from `after+1` +""" +function _gen_set!(v::AbstractVector, after::Integer, gen) + for (i, val) in enumerate(gen) + v[after+i] = val + end +end + +""" +Return generator over interpolated values. 
+""" +function _gen_interp(a, b, n, ::Nothing) + inc = _calculate_increment(a, b, n) + (a + inc*i for i=1:n) +end + +function _gen_interp(a, b, n, r::RoundingMode) + inc = _calculate_increment(a, b, n) + (round(a + inc*i, r) for i=1:n) +end + +function _gen_interp(a::T, b::T, n, ::Nothing) where {T<:Integer} + inc = _calculate_increment(a, b, n) + (convert(T, a + inc*i) for i=1:n) +end + +function _gen_interp(a::T, b::T, n, r::RoundingMode) where {T<:Integer} + inc = _calculate_increment(a, b, n) + (round(T, a + inc*i, r) for i=1:n) +end + +_calculate_increment(a, b, n) = (b - a) / n + +function _calculate_increment(a::T, b::T, n) where {T<:Integer} + _calculate_increment(float(a), float(b), n) +end -# Pass a rounding mode to `div` -_calculate_increment(a, b, n, r) = div(b - a, n, r) -_calculate_increment(a::Unsigned, b::Unsigned, n, r) = _calculate_increment(Int(a), Int(b), n, r) diff --git a/test/imputors/interp.jl b/test/imputors/interp.jl index 60a8b32..d849166 100644 --- a/test/imputors/interp.jl +++ b/test/imputors/interp.jl @@ -115,7 +115,7 @@ # Test with UInt c = [0x1, missing, 0x2, 0x3] - @test Impute.interp(c; r=RoundNearest) == [0x1, 0x1, 0x2, 0x3] + @test Impute.interp(c; r=RoundNearest) == [0x1, 0x2, 0x2, 0x3] # Test reverse case where the increment is negative @test Impute.interp(reverse(c); r=RoundUp) == [0x3, 0x2, 0x2, 0x1] @@ -125,6 +125,24 @@ @test Impute.interp([2, missing, missing, 1]; r=RoundUp) == [2, 2, 2, 1] @test Impute.interp([1, missing, missing, 0]; r=RoundDown) == [1, 0, 0, 0] @test Impute.interp([0x1, missing, missing, 0x0]; r=RoundDown) == [0x1, 0x0, 0x0, 0x0] + + # Test long gaps (above .5 increment) + @test Impute.interp([2, fill(missing, 10)..., 8]; r=RoundNearest) == [2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8] + @test Impute.interp([0x2, fill(missing, 10)..., 0x8]; r=RoundNearest) == [0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8] + @test Impute.interp([8, fill(missing, 10)..., 2]; r=RoundNearest) == [8, 7, 7, 6, 6, 5, 5, 4, 
4, 3, 3, 2] + @test Impute.interp([0x8, fill(missing, 10)..., 0x2]; r=RoundNearest) == [0x8, 0x7, 0x7, 0x6, 0x6, 0x5, 0x5, 0x4, 0x4, 0x3, 0x3, 0x2] + + # Test long gaps (at .5 increment) + @test Impute.interp([2, fill(missing, 11)..., 8]; r=RoundNearest) == [2, 2, 3, 4, 4, 4, 5, 6, 6, 6, 7, 8, 8] + @test Impute.interp([0x2, fill(missing, 11)..., 0x8]; r=RoundNearest) == [0x2, 0x2, 0x3, 0x4, 0x4, 0x4, 0x5, 0x6, 0x6, 0x6, 0x7, 0x8, 0x8] + @test Impute.interp([8, fill(missing, 11)..., 2]; r=RoundNearest) == [8, 8, 7, 6, 6, 6, 5, 4, 4, 4, 3, 2, 2] + @test Impute.interp([0x8, fill(missing, 11)..., 0x2]; r=RoundNearest) == [0x8, 0x8, 0x7, 0x6, 0x6, 0x6, 0x5, 0x4, 0x4, 0x4, 0x3, 0x2, 0x2] + + # Test long gaps (below .5 increment) + @test Impute.interp([2, fill(missing, 12)..., 8]; r=RoundNearest) == [2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8] + @test Impute.interp([0x2, fill(missing, 12)..., 0x8]; r=RoundNearest) == [0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8] + @test Impute.interp([8, fill(missing, 12)..., 2]; r=RoundNearest) == [8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2] + @test Impute.interp([0x8, fill(missing, 12)..., 0x2]; r=RoundNearest) == [0x8, 0x8, 0x7, 0x7, 0x6, 0x6, 0x5, 0x5, 0x4, 0x4, 0x3, 0x3, 0x2, 0x2] end # TODO Test error cases on non-numeric types From ddc8a9c3c408974ad4ee32e478b252b24ae1a873 Mon Sep 17 00:00:00 2001 From: kpa28 Date: Tue, 16 Apr 2024 20:47:01 -0700 Subject: [PATCH 6/9] ignore RoundingMode for non-integers like floats --- src/imputors/interp.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index 01a29be..279c4bd 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -86,10 +86,7 @@ function _gen_interp(a, b, n, ::Nothing) (a + inc*i for i=1:n) end -function _gen_interp(a, b, n, r::RoundingMode) - inc = _calculate_increment(a, b, n) - (round(a + inc*i, r) for i=1:n) -end +_gen_interp(a, b, n, r::RoundingMode) = _gen_interp(a, b, n, 
nothing) function _gen_interp(a::T, b::T, n, ::Nothing) where {T<:Integer} inc = _calculate_increment(a, b, n) From 26ef7e0292f141ad7e499d7711ddea7c953bcd7a Mon Sep 17 00:00:00 2001 From: rofinn Date: Tue, 16 Apr 2024 22:27:05 -0700 Subject: [PATCH 7/9] Only dispatch on Unsigned increment calculation and rounding non-integer values to integers on insertion. --- src/imputors/interp.jl | 54 +++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index 279c4bd..b1d0e64 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -54,8 +54,16 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w gap_sz = (next_idx - prev_idx) - 1 if imp.limit === nothing || gap_sz <= imp.limit - gen = _gen_interp(data[prev_idx], data[next_idx], gap_sz+1, imp.r) - _gen_set!(data, prev_idx, gen) + prev = data[prev_idx] + next = data[next_idx] + incr = _calculate_increment(prev, next, gap_sz + 1) + val = prev + incr + + # Iteratively fill in the values + for j in i:(next_idx - 1) + _setindex!(data, val, j, imp.r) + val += incr + end end i = next_idx @@ -69,38 +77,14 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end -""" -Set a vector slice over the values of a generator, starting from `after+1` -""" -function _gen_set!(v::AbstractVector, after::Integer, gen) - for (i, val) in enumerate(gen) - v[after+i] = val - end -end - -""" -Return generator over interpolated values. 
-""" -function _gen_interp(a, b, n, ::Nothing) - inc = _calculate_increment(a, b, n) - (a + inc*i for i=1:n) -end - -_gen_interp(a, b, n, r::RoundingMode) = _gen_interp(a, b, n, nothing) - -function _gen_interp(a::T, b::T, n, ::Nothing) where {T<:Integer} - inc = _calculate_increment(a, b, n) - (convert(T, a + inc*i) for i=1:n) -end - -function _gen_interp(a::T, b::T, n, r::RoundingMode) where {T<:Integer} - inc = _calculate_increment(a, b, n) - (round(T, a + inc*i, r) for i=1:n) -end - +# Calculating an increment value _calculate_increment(a, b, n) = (b - a) / n - -function _calculate_increment(a::T, b::T, n) where {T<:Integer} - _calculate_increment(float(a), float(b), n) +# Special case for unsigned to avoid integer overflow +_calculate_increment(a::T, b::T, n) where {T<:Unsigned} = _calculate_increment(Int(a), Int(b), n) + +# For handling rounding on insertions +_setindex!(data, val, i, r) = setindex!(data, val, i) +# Special case for rounding non-integer values on insertion into integer arrays. +function _setindex!(data::AbstractVector{<:Union{T, Missing}}, val::S, i, r::RoundingMode) where {T <: Integer, S} + T === S ? setindex!(data, val, i) : setindex!(data, round(T, val, r), i) end - From a2f8d07f741533ad4b29a6a9f1e6d2655598baf2 Mon Sep 17 00:00:00 2001 From: kpa28 Date: Wed, 17 Apr 2024 00:21:59 -0700 Subject: [PATCH 8/9] refactor of interpolate integers in floating point using generators... 
--- src/imputors/interp.jl | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index b1d0e64..f276c58 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -54,16 +54,9 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w gap_sz = (next_idx - prev_idx) - 1 if imp.limit === nothing || gap_sz <= imp.limit - prev = data[prev_idx] - next = data[next_idx] - incr = _calculate_increment(prev, next, gap_sz + 1) - val = prev + incr - - # Iteratively fill in the values - for j in i:(next_idx - 1) - _setindex!(data, val, j, imp.r) - val += incr - end + inc = _calculate_increment(data[prev_idx], data[next_idx], gap_sz+1) + gen = _gen_interp(data[prev_idx], inc, gap_sz+1, imp.r) + _gen_set!(data, prev_idx, gen) end i = next_idx @@ -77,14 +70,16 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end -# Calculating an increment value +# sets vector slice via a generator (faster) +function _gen_set!(v::AbstractVector, after::Integer, gen) + for (i, val) in enumerate(gen) + v[after+i] = val + end +end + +# generator of interpolated values +_gen_interp(a, inc, n, r) = (a + inc*i for i=1:n) +_gen_interp(a::T, inc, n, r::RoundingMode) where {T<:Integer} = (round(T, a + inc*i, r) for i=1:n) + _calculate_increment(a, b, n) = (b - a) / n -# Special case for unsigned to avoid integer overflow _calculate_increment(a::T, b::T, n) where {T<:Unsigned} = _calculate_increment(Int(a), Int(b), n) - -# For handling rounding on insertions -_setindex!(data, val, i, r) = setindex!(data, val, i) -# Special case for rounding non-integer values on insertion into integer arrays. -function _setindex!(data::AbstractVector{<:Union{T, Missing}}, val::S, i, r::RoundingMode) where {T <: Integer, S} - T === S ? 
setindex!(data, val, i) : setindex!(data, round(T, val, r), i) -end From 486816ae3e44eb05a5964bcdcdae0bc0aa6393bd Mon Sep 17 00:00:00 2001 From: rofinn Date: Thu, 18 Apr 2024 17:40:12 -0700 Subject: [PATCH 9/9] Drop the generator 1. We just need a kernel function barrier to avoid inner-loop type instability with integers 2. The generator caused overhead for the default floats case 3. Simplified some of the indexing logic since the kernel function can handle most of it. 4. Add some comments about why all these extra internal functions exist. --- src/imputors/interp.jl | 43 +++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index f276c58..44dc1ca 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -47,19 +47,14 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w while i < lastindex(data) if ismissing(data[i]) - prev_idx = i - 1 - next_idx = findnext(!ismissing, data, i + 1) + j = _findnext(data, i + 1) - if next_idx !== nothing - gap_sz = (next_idx - prev_idx) - 1 - - if imp.limit === nothing || gap_sz <= imp.limit - inc = _calculate_increment(data[prev_idx], data[next_idx], gap_sz+1) - gen = _gen_interp(data[prev_idx], inc, gap_sz+1, imp.r) - _gen_set!(data, prev_idx, gen) + if j !== nothing + if imp.limit === nothing || j - i + 1 <= imp.limit + _interpolate!(data, i:j, data[i - 1], data[j + 1], imp.r) end - i = next_idx + i = j + 1 else break end @@ -70,16 +65,30 @@ function _impute!(data::AbstractVector{<:Union{T, Missing}}, imp::Interpolate) w return data end -# sets vector slice via a generator (faster) -function _gen_set!(v::AbstractVector, after::Integer, gen) - for (i, val) in enumerate(gen) - v[after+i] = val +# Our kernel function used to avoid type instability issues. 
+# https://docs.julialang.org/en/v1/manual/performance-tips/#kernel-functions +function _interpolate!(data, indices, prev, next, r) + incr = _calculate_increment(prev, next, length(indices) + 1) + + for (i, k) in enumerate(indices) + data[k] = _calculate_value(prev, incr, i, r) end end -# generator of interpolated values -_gen_interp(a, inc, n, r) = (a + inc*i for i=1:n) -_gen_interp(a::T, inc, n, r::RoundingMode) where {T<:Integer} = (round(T, a + inc*i, r) for i=1:n) +# Utility function for finding the last index within a missing data block +function _findnext(data, i) + j = findnext(!ismissing, data, i) + j === nothing && return j + return j - 1 +end +# Calculates the increment for interpolation _calculate_increment(a, b, n) = (b - a) / n +# Special case for avoiding integer overflow _calculate_increment(a::T, b::T, n) where {T<:Unsigned} = _calculate_increment(Int(a), Int(b), n) + +# Calculates the interpolated value for a given iteration i +# Default case of simply prev + incr * i +_calculate_value(prev, incr, i, r) = prev + incr * i +# Special case for rounding integers +_calculate_value(prev::T, incr, i, r::RoundingMode) where {T<:Integer} = round(T, prev + incr * i, r)