diff --git a/Project.toml b/Project.toml index 2c345ff7..a9262e93 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" @@ -23,6 +24,7 @@ StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" CategoricalArraysArrowExt = "Arrow" CategoricalArraysJSONExt = "JSON" CategoricalArraysRecipesBaseExt = "RecipesBase" +CategoricalArraysStatsBaseExt = "StatsBase" CategoricalArraysSentinelArraysExt = "SentinelArrays" CategoricalArraysStructTypesExt = "StructTypes" @@ -37,6 +39,7 @@ RecipesBase = "1.1" Requires = "1" SentinelArrays = "1" Statistics = "1" +StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34" StructTypes = "1" julia = "1.6" @@ -49,8 +52,9 @@ PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RecipesPipeline = "01d81517-befc-4cb6-b9ec-a95719d0359c" SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StructTypes", "Test"] +test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StatsBase", "StructTypes", "Test"] diff --git a/ext/CategoricalArraysStatsBaseExt.jl b/ext/CategoricalArraysStatsBaseExt.jl new file mode 100644 index 00000000..8cbd5c61 --- /dev/null +++ b/ext/CategoricalArraysStatsBaseExt.jl @@ -0,0 +1,13 @@ +module CategoricalArraysStatsBaseExt + +if isdefined(Base, :get_extension) + import CategoricalArrays: _wquantile + using StatsBase +else + import ..CategoricalArrays: _wquantile + using ..StatsBase +end + +_wquantile(x::AbstractArray, w::AbstractWeights, p::AbstractVector) = quantile(x, w, p) + +end diff --git a/src/CategoricalArrays.jl b/src/CategoricalArrays.jl index f3383645..f44b3c2f 100644 --- a/src/CategoricalArrays.jl +++ b/src/CategoricalArrays.jl @@ -45,6 +45,7 @@ module CategoricalArrays @require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl") @require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl") @require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl") + @require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl") @require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl") end end diff --git a/src/extras.jl b/src/extras.jl index 60f32a64..910c6e46 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -337,11 +337,17 @@ function find_breaks(v::AbstractVector, qs::AbstractVector) return breaks end +# AbstractWeights method is defined in StatsBase extension +# There is no in-place weighted quantile method in StatsBase +_wquantile(x::AbstractArray, w::AbstractVector, p::AbstractVector) = + throw(ArgumentError("`weights` must be an `AbstractWeights` vector from StatsBase.jl")) + """ cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:AbstractString},Function}, sigdigits::Integer=3, - allowempty::Bool=false) + allowempty::Bool=false, + weights::Union{AbstractWeights, Nothing}=nothing) Cut a numeric array into `ngroups` quantiles. @@ -373,19 +379,41 @@ quantiles. other than the last one are equal, generating empty intervals; when `true`, duplicate breaks are allowed and the intervals they generate are kept as unused levels (but duplicate labels are not allowed). +* `weights::Union{AbstractWeights, Nothing}=nothing`: observations weights to used when + computing quantiles (see `quantile` documentation in StatsBase). """ function cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing, sigdigits::Integer=3, - allowempty::Bool=false) + allowempty::Bool=false, + weights::Union{AbstractVector, Nothing}=nothing) ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)")) - sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) - min_x, max_x = first(sorted_x), last(sorted_x) - if (min_x isa Number && isnan(min_x)) || - (max_x isa Number && isnan(max_x)) - throw(ArgumentError("NaN values are not allowed in input vector")) + if weights === nothing + sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x) + min_x, max_x = first(sorted_x), last(sorted_x) + if (min_x isa Number && isnan(min_x)) || + (max_x isa Number && isnan(max_x)) + throw(ArgumentError("NaN values are not allowed in input vector")) + end + qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) + else + if eltype(x) >: Missing + nm_inds = findall(!ismissing, x) + nm_x = view(x, nm_inds) + # TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723) + nm_weights = weights[nm_inds] + else + nm_x = x + nm_weights = weights + end + sorted_x = sort(nm_x) + min_x, max_x = first(sorted_x), last(sorted_x) + if (min_x isa Number && isnan(min_x)) || + (max_x isa Number && isnan(max_x)) + throw(ArgumentError("NaN values are not allowed in input vector")) + end + qs = _wquantile(nm_x, nm_weights, (1:(ngroups-1))/ngroups) end - qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true) breaks = [min_x; find_breaks(sorted_x, qs); max_x] if !allowempty && !allunique(@view breaks[1:end-1]) throw(ArgumentError("cannot compute $ngroups quantiles due to " * diff --git a/test/15_extras.jl b/test/15_extras.jl index 5df7860b..80dc14b7 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -1,6 +1,8 @@ module TestExtras using Test using CategoricalArrays +using StatsBase +using Missings const ≅ = isequal @@ -423,4 +425,26 @@ end end +@testset "cut with weighted quantiles" begin + @test_throws ArgumentError cut(1:3, 3, weights=1:3) + + x = collect(Float64, 1:100) + w = fweights(repeat(1:10, inner=10)) + y = cut(x, 10, weights=w) + @test levelcode.(y) == levelcode.(cut(x, quantile(x, w, (0:10)./10))) + @test levels(y) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)", + "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"] + + mx = allowmissing(x) + mx[2] = mx[10] = missing + nm_inds = .!ismissing.(mx) + y = cut(mx, 10, weights=w) + @test levelcode.(y) ≅ levelcode.(cut(mx, quantile(x[nm_inds], w[nm_inds], (0:10)./10))) + @test levels(y) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)", + "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"] + + x[5] = NaN + @test_throws ArgumentError cut(x, 3, weights=w) +end + end \ No newline at end of file