Skip to content

Commit

Permalink
Merge pull request #74 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.1.15 release
  • Loading branch information
ablaom authored Mar 22, 2024
2 parents 79ee322 + 8011efb commit daa0879
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CategoricalDistributions"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.14"
version = "0.1.15"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
15 changes: 9 additions & 6 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ _check_augmentable(support, probs) = _check_probs_01(probs) &&
size(probs)[end] + 1 == length(support) ||
throw(err_dim_augmented(support, probs))

err_incompatible_pool(support, classes) = ArgumentError(
"Specified support, $support, not contained in "*
"specified pool, $(levels(classes)). "
)


## AUGMENTING ARRAYS TO MAKE THEM PROBABILITY ARRAYS

Expand Down Expand Up @@ -277,7 +282,7 @@ function _augment_probs(::Val{true},
end

_array_or_scalar(x::Array) = x
_array_or_scalar(x::AbstractArray) = copyto!(similar(Array{eltype(x)}, axes(x)), x)
_array_or_scalar(x::AbstractArray) = copyto!(similar(Array{eltype(x)}, axes(x)), x)
_array_or_scalar(x) = x

## CONSTRUCTORS - FROM DICTIONARY
Expand Down Expand Up @@ -453,14 +458,12 @@ function _UnivariateFinite(support,
_support = classes(v)
else
_classes = classes(pool)
issubset(support, _classes) ||
error("Specified support, $support, not contained in "*
"specified pool, $(levels(classes)). ")
issubset(support, _classes) || throw(err_incompatible_pool(support, _classes))
idxs = getindex.(
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
Ref(CategoricalArrays.DataAPI.invrefpool(_classes)),
support
)
_support = _classes[idxs]
_support = _classes[idxs]
end

# calls core method:
Expand Down
12 changes: 9 additions & 3 deletions test/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ import CategoricalDistributions: classes
@test_logs((:warn, r"No "),
UnivariateFinite(['f', 'q', 's'], [0.7, 0.2, 0.1]))

junk = ["F", "Q", "S"]
@test_throws(
CategoricalDistributions.err_incompatible_pool(junk, classes(v)),
UnivariateFinite(junk, [0.1, 0.9], pool=v),
)

end

@testset "array constructors" begin
Expand All @@ -37,7 +43,7 @@ end

UnivariateFinite(supp, probs, pool=missing, augment=true);

# construction from pool and support does not
# construction from pool and support does not
# consist of categorical elements (See issue #34)
v = categorical(["x", "x", "y", "z", "y", "z", "p"])
probs1 = [0.1, 0.2, 0.7]
Expand Down Expand Up @@ -75,7 +81,7 @@ end
v = categorical(['x', 'x', 'y', 'x', 'z', 'w'])
probs_fillarray = FillArrays.Ones(100, 3)
probs_array = ones(100, 3)

probs1_fillarray = FillArrays.Fill(0.2, 100, 2)
probs1_array = fill(0.2, 100, 2)

Expand All @@ -88,7 +94,7 @@ end
u1_from_fillarray = UnivariateFinite(
['x', 'y', 'z'], probs1_fillarray, pool=v, augment=true
)

@test u_from_array.prob_given_ref == u_from_fillarray.prob_given_ref
@test u1_from_array.prob_given_ref == u1_from_fillarray.prob_given_ref

Expand Down

0 comments on commit daa0879

Please sign in to comment.