Skip to content

Commit

Permalink
functions that only support finite weights now throw errors for non-f…
Browse files Browse the repository at this point in the history
…inites (#914)

* throw errors when only finite weights are supported

* remove extra calls to sum(wv)

* typo

* typo

* add a minimal test for custom weights implementations

* fix new test on 1.0
  • Loading branch information
aplavin committed Jun 17, 2024
1 parent c022f82 commit 87f372c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,9 @@ Optionally specify a random number generator `rng` as the first argument
function sample(rng::AbstractRNG, wv::AbstractWeights)
1 == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
t = rand(rng) * sum(wv)
wsum = sum(wv)
isfinite(wsum) || throw(ArgumentError("only finite weights are supported"))
t = rand(rng) * wsum
n = length(wv)
i = 1
cw = wv[1]
Expand Down Expand Up @@ -654,6 +656,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights,
throw(ArgumentError("output array x must not share memory with input array a"))
1 == firstindex(a) == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
length(wv) == length(a) || throw(DimensionMismatch("Inconsistent lengths."))

# create alias table
Expand Down Expand Up @@ -688,13 +691,14 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
wsum = sum(wv)
isfinite(wsum) || throw(ArgumentError("only finite weights are supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
k = length(x)

w = Vector{Float64}(undef, n)
copyto!(w, wv)
wsum = sum(wv)

for i = 1:k
u = rand(rng) * wsum
Expand Down Expand Up @@ -734,6 +738,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
Expand Down Expand Up @@ -775,6 +780,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
Expand Down Expand Up @@ -848,6 +854,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
Expand Down
2 changes: 2 additions & 0 deletions src/scalarstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ end
# Weighted mode of arbitrary vectors of values
function mode(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
length(a) == length(wv) ||
throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))"))

Expand All @@ -184,6 +185,7 @@ end

function modes(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real
isempty(a) && throw(ArgumentError("mode is not defined for empty collections"))
isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported"))
length(a) == length(wv) ||
throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))"))

Expand Down
1 change: 1 addition & 0 deletions src/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ function quantile(v::AbstractVector{<:Real}{V}, w::AbstractWeights{W}, p::Abstra
# checks
isempty(v) && throw(ArgumentError("quantile of an empty array is undefined"))
isempty(p) && throw(ArgumentError("empty quantile array"))
isfinite(sum(w)) || throw(ArgumentError("only finite weights are supported"))
all(x -> 0 <= x <= 1, p) || throw(ArgumentError("input probability out of [0,1] range"))

w.sum == 0 && throw(ArgumentError("weight vector cannot sum to zero"))
Expand Down
16 changes: 16 additions & 0 deletions test/weights.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
using StatsBase
using LinearAlgebra, Random, SparseArrays, Test


# Minimal custom weights type for the tests below: a concrete subtype of
# `AbstractWeights` holding Float64 weights together with their total.
struct MyWeights <: AbstractWeights{Float64, Float64, Vector{Float64}}
    values::Vector{Float64}  # the individual weights
    sum::Float64             # total of the weights, stored alongside them
end

# Convenience constructor: compute the stored total from the given weights.
function MyWeights(values)
    return MyWeights(values, sum(values))
end


@testset "StatsBase.Weights" begin
weight_funcs = (weights, aweights, fweights, pweights)

Expand Down Expand Up @@ -610,4 +619,11 @@ end
end
end

# Exercises a user-defined `AbstractWeights` subtype (`MyWeights`) with `mean`
# and `mode`: a NaN weight yields a NaN mean, while `mode` rejects non-finite
# weights with an `ArgumentError`.
@testset "custom weight types" begin
    # (1*1 + 2*4 + 3*10) / (1 + 4 + 10) = 2.6; compare approximately since the
    # result is a floating-point value.
    @test mean([1, 2, 3], MyWeights([1, 4, 10])) ≈ 2.6
    @test mean([1, 2, 3], MyWeights([NaN, 4, 10])) |> isnan
    @test mode([1, 2, 3], MyWeights([1, 4, 10])) == 3
    @test_throws ArgumentError mode([1, 2, 3], MyWeights([NaN, 4, 10]))
end

end # @testset StatsBase.Weights

0 comments on commit 87f372c

Please sign in to comment.