Skip to content

Commit

Permalink
Merge pull request #73 from JuliaAI/better-error
Browse files Browse the repository at this point in the history
Fix bug in error for UnivariateFinite constructor
  • Loading branch information
ablaom authored Mar 22, 2024
2 parents 6559cb5 + 5f33ce8 commit 52f8078
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
15 changes: 9 additions & 6 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ _check_augmentable(support, probs) = _check_probs_01(probs) &&
size(probs)[end] + 1 == length(support) ||
throw(err_dim_augmented(support, probs))

err_incompatible_pool(support, classes) = ArgumentError(
"Specified support, $support, not contained in "*
"specified pool, $(levels(classes)). "
)


## AUGMENTING ARRAYS TO MAKE THEM PROBABILITY ARRAYS

Expand Down Expand Up @@ -277,7 +282,7 @@ function _augment_probs(::Val{true},
end

_array_or_scalar(x::Array) = x
_array_or_scalar(x::AbstractArray) = copyto!(similar(Array{eltype(x)}, axes(x)), x)
_array_or_scalar(x::AbstractArray) = copyto!(similar(Array{eltype(x)}, axes(x)), x)
_array_or_scalar(x) = x

## CONSTRUCTORS - FROM DICTIONARY
Expand Down Expand Up @@ -453,14 +458,12 @@ function _UnivariateFinite(support,
_support = classes(v)
else
_classes = classes(pool)
issubset(support, _classes) ||
error("Specified support, $support, not contained in "*
"specified pool, $(levels(classes)). ")
issubset(support, _classes) || throw(err_incompatible_pool(support, _classes))
idxs = getindex.(
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
support
)
_support = _classes[idxs]
_support = _classes[idxs]
end

# calls core method:
Expand Down
12 changes: 9 additions & 3 deletions test/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ import CategoricalDistributions: classes
@test_logs((:warn, r"No "),
UnivariateFinite(['f', 'q', 's'], [0.7, 0.2, 0.1]))

junk = ["F", "Q", "S"]
@test_throws(
CategoricalDistributions.err_incompatible_pool(junk, classes(v)),
UnivariateFinite(junk, [0.1, 0.9], pool=v),
)

end

@testset "array constructors" begin
Expand All @@ -37,7 +43,7 @@ end

UnivariateFinite(supp, probs, pool=missing, augment=true);

# construction from pool and support does not
# construction from pool and support does not
# consist of categorical elements (See issue #34)
v = categorical(["x", "x", "y", "z", "y", "z", "p"])
probs1 = [0.1, 0.2, 0.7]
Expand Down Expand Up @@ -75,7 +81,7 @@ end
v = categorical(['x', 'x', 'y', 'x', 'z', 'w'])
probs_fillarray = FillArrays.Ones(100, 3)
probs_array = ones(100, 3)

probs1_fillarray = FillArrays.Fill(0.2, 100, 2)
probs1_array = fill(0.2, 100, 2)

Expand All @@ -88,7 +94,7 @@ end
u1_from_fillarray = UnivariateFinite(
['x', 'y', 'z'], probs1_fillarray, pool=v, augment=true
)

@test u_from_array.prob_given_ref == u_from_fillarray.prob_given_ref
@test u1_from_array.prob_given_ref == u1_from_fillarray.prob_given_ref

Expand Down

0 comments on commit 52f8078

Please sign in to comment.