Skip to content

Support weighted quantiles in cut #423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 21, 2025
Merged
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"

[extensions]
CategoricalArraysArrowExt = "Arrow"
CategoricalArraysJSONExt = "JSON"
CategoricalArraysRecipesBaseExt = "RecipesBase"
CategoricalArraysStatsBaseExt = "StatsBase"
CategoricalArraysSentinelArraysExt = "SentinelArrays"
CategoricalArraysStructTypesExt = "StructTypes"

Expand All @@ -37,6 +39,7 @@ RecipesBase = "1.1"
Requires = "1"
SentinelArrays = "1"
Statistics = "1"
StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
StructTypes = "1"
julia = "1.6"

Expand All @@ -49,8 +52,9 @@ PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecipesPipeline = "01d81517-befc-4cb6-b9ec-a95719d0359c"
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StructTypes", "Test"]
test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StatsBase", "StructTypes", "Test"]
13 changes: 13 additions & 0 deletions ext/CategoricalArraysStatsBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module CategoricalArraysStatsBaseExt

if isdefined(Base, :get_extension)
import CategoricalArrays: _wquantile
using StatsBase
else
import ..CategoricalArrays: _wquantile
using ..StatsBase
end

_wquantile(x::AbstractArray, w::AbstractWeights, p::AbstractVector) = quantile(x, w, p)

end
1 change: 1 addition & 0 deletions src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ module CategoricalArrays
@require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl")
@require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl")
@require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl")
@require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl")
@require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl")
end
end
Expand Down
44 changes: 36 additions & 8 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,17 @@ function find_breaks(v::AbstractVector, qs::AbstractVector)
return breaks
end

# AbstractWeights method is defined in StatsBase extension
# There is no in-place weighted quantile method in StatsBase
_wquantile(x::AbstractArray, w::AbstractVector, p::AbstractVector) =
throw(ArgumentError("`weights` must be an `AbstractWeights` vector from StatsBase.jl"))

"""
cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function},
sigdigits::Integer=3,
allowempty::Bool=false)
allowempty::Bool=false,
weights::Union{AbstractWeights, Nothing}=nothing)

Cut a numeric array into `ngroups` quantiles.

Expand Down Expand Up @@ -373,19 +379,41 @@ quantiles.
other than the last one are equal, generating empty intervals;
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
unused levels (but duplicate labels are not allowed).
* `weights::Union{AbstractWeights, Nothing}=nothing`: observations weights to used when
computing quantiles (see `quantile` documentation in StatsBase).
"""
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
sigdigits::Integer=3,
allowempty::Bool=false)
allowempty::Bool=false,
weights::Union{AbstractVector, Nothing}=nothing)
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
if weights === nothing
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
else
if eltype(x) >: Missing
nm_inds = findall(!ismissing, x)
nm_x = view(x, nm_inds)
# TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723)
nm_weights = weights[nm_inds]
else
nm_x = x
nm_weights = weights
end
sorted_x = sort(nm_x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
qs = _wquantile(nm_x, nm_weights, (1:(ngroups-1))/ngroups)
end
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
if !allowempty && !allunique(@view breaks[1:end-1])
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
Expand Down
24 changes: 24 additions & 0 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module TestExtras
using Test
using CategoricalArrays
using StatsBase
using Missings

const ≅ = isequal

Expand Down Expand Up @@ -423,4 +425,26 @@ end

end

@testset "cut with weighted quantiles" begin
@test_throws ArgumentError cut(1:3, 3, weights=1:3)

x = collect(Float64, 1:100)
w = fweights(repeat(1:10, inner=10))
y = cut(x, 10, weights=w)
@test levelcode.(y) == levelcode.(cut(x, quantile(x, w, (0:10)./10)))
@test levels(y) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
"[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]

mx = allowmissing(x)
mx[2] = mx[10] = missing
nm_inds = .!ismissing.(mx)
y = cut(mx, 10, weights=w)
@test levelcode.(y) ≅ levelcode.(cut(mx, quantile(x[nm_inds], w[nm_inds], (0:10)./10)))
@test levels(y) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
"[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]

x[5] = NaN
@test_throws ArgumentError cut(x, 3, weights=w)
end

end
Loading