Skip to content

Commit

Permalink
feat(demo): heterogeneous ensembles
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoisot committed Sep 22, 2024
1 parent 5f9f982 commit 58a8200
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 52 deletions.
26 changes: 26 additions & 0 deletions SDeMo/.JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
indent = 4
margin = 92
always_for_in = true
whitespace_typedefs = true
whitespace_ops_in_indices = true
remove_extra_newlines = true
import_to_using = false
pipe_to_function_call = false
short_to_long_function_def = false
long_to_short_function_def = false
always_use_return = true
whitespace_in_kwargs = true
annotate_untyped_fields_with_any = true
format_docstrings = true
conditional_to_if = true
normalize_line_endings = "unix"
trailing_comma = true
join_lines_based_on_source = true
indent_submodule = true
separate_kwargs_with_semicolon = true
surround_whereop_typeparameters = true
overwrite = true
verbose = true
format_markdown = true
align_struct_field = true
align_pair_arrow = true
22 changes: 22 additions & 0 deletions SDeMo/docs/src/demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,28 @@ train!(ensemble)
uncert = predict(ensemble; consensus=iqr, threshold=false)
hist(uncert, color=:grey; axis=(; xlabel="Uncertainty (IQR)"))

# ## Heterogeneous ensembles

# We can setup an heterogeneous ensemble model by passing several SDMs to
# `Ensemble`:

m1 = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
m2 = SDM(RawData, BIOCLIM, X, y)
m3 = SDM(MultivariateTransform{PCA}, BIOCLIM, X, y)
variables!(m2, [1, 12])
hm = Ensemble(m1, m2, m3)

# We can train this model in the same way:

train!(hm)

# And get predictions:

predict(hm; consensus=median, threshold=false)[1:10]

# Note taht *for now*, `Ensemble` and `Bagging` models are not supported by
# methods like `variableimportance`, `partialresponse`, etc.

# ## Explaining predictions

# We can perform the (MCMC version of) Shapley values measurement, using the
Expand Down
6 changes: 6 additions & 0 deletions SDeMo/docs/src/ensembles.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ outofbag
bootstrap
iqr
```

## Heterogeneous ensembles

```@docs
Ensemble
```
2 changes: 2 additions & 0 deletions SDeMo/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ Classifier
features
labels
threshold
threshold!
variables
variables!
instance
```

Expand Down
11 changes: 8 additions & 3 deletions SDeMo/src/SDeMo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include("models.jl")
export Transformer, Classifier
export SDM
export threshold, features, labels, variables, instance
export threshold!, variables!

# Univariate transforms
include("transformers/univariate.jl")
Expand All @@ -37,11 +38,15 @@ export NaiveBayes
include("classifiers/bioclim.jl")
export BIOCLIM

# Bagging
include("bagging/bootstrap.jl")
include("bagging/pipeline.jl")
# Bagging and ensembles
include("ensembles/bagging.jl")
export Bagging, outofbag, bootstrap

include("ensembles/ensemble.jl")
export Ensemble

include("ensembles/pipeline.jl")

# Main pipeline
include("pipeline.jl")
export reset!, train!, predict
Expand Down
39 changes: 0 additions & 39 deletions SDeMo/src/bagging/pipeline.jl

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
bootstrap(y, X; n = 50)
"""
function bootstrap(y, X; n = 50)
function bootstrap(y, X; n=50)
@assert size(y, 1) == size(X, 2)
bags = []
for _ in 1:n
inbag = sample(1:size(X, 2), size(X, 2); replace = true)
inbag = sample(1:size(X, 2), size(X, 2); replace=true)
outbag = setdiff(axes(X, 2), inbag)
push!(bags, (inbag, outbag))
end
Expand All @@ -24,7 +24,7 @@ end
"""
mutable struct Bagging
model::SDM
bags::Vector{Tuple{Vector{Int64}, Vector{Int64}}}
bags::Vector{Tuple{Vector{Int64},Vector{Int64}}}
models::Vector{SDM}
end

Expand All @@ -43,7 +43,7 @@ end
Creates a bag from SDM
"""
function Bagging(model::SDM, n::Integer)
bags = bootstrap(labels(model), features(model); n = n)
bags = bootstrap(labels(model), features(model); n=n)
return Bagging(model, bags, [deepcopy(model) for _ in eachindex(bags)])
end

Expand Down
26 changes: 26 additions & 0 deletions SDeMo/src/ensembles/ensemble.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Ensemble
An heterogeneous ensemble model is defined as a vector of `SDM`s.
"""
mutable struct Ensemble
models::Vector{<:SDM}
end

Ensemble(m::T...) where {T <: SDM} = Ensemble([m...])

@testitem "We can setup an ensemble" begin
X, y = SDeMo.__demodata()
m1 = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
m2 = SDM(ZScore, BIOCLIM, X, y)
ens = Ensemble([m1, m2])
@test ens isa Ensemble
end

@testitem "We can setup an ensemble the other way" begin
X, y = SDeMo.__demodata()
m1 = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
m2 = SDM(ZScore, BIOCLIM, X, y)
ens = Ensemble(m1, m2)
@test ens isa Ensemble
end
81 changes: 81 additions & 0 deletions SDeMo/src/ensembles/pipeline.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
train!(ensemble::Bagging; kwargs...)
Trains all the model in an ensemble model - the keyword arguments are passed to
`train!` for each model. Note that this retrains the *entire* model, which
includes the transformers.
"""
function train!(ensemble::Bagging; kwargs...)
Threads.@threads for m in eachindex(ensemble.models)
train!(ensemble.models[m]; training = ensemble.bags[m][1], kwargs...)
end
train!(ensemble.model; kwargs...)
return ensemble
end

"""
StatsAPI.predict(ensemble::Bagging, X; consensus = median, kwargs...)
Returns the prediction for the ensemble of models a dataset `X`. The function
used to aggregate the outputs from different models is `consensus` (defaults to
`median`). All other keyword arguments are passed to `predict`.
To get a direct estimate of the variability, the `consensus` function can be
changed to `iqr` (inter-quantile range), or any measure of variance.
"""
function StatsAPI.predict(ensemble::Bagging, X; consensus = median, kwargs...)
= [predict(component, X; kwargs...) for component in ensemble.models]
= vec(mapslices(consensus, hcat(ŷ...); dims = 2))
return isone(length(ỹ)) ? only(ỹ) :
end

"""
StatsAPI.predict(ensemble::Bagging; kwargs...)
Predicts the ensemble model for all training data.
"""
function StatsAPI.predict(ensemble::Bagging; kwargs...)
return StatsAPI.predict(ensemble, ensemble.model.X; kwargs...)
end

"""
train!(ensemble::Ensemble; kwargs...)
Trains all the model in an heterogeneous ensemble model - the keyword arguments
are passed to `train!` for each model. Note that this retrains the *entire*
model, which includes the transformers.
The keywod arguments are passed to `train!` and can include the `training`
indices.
"""
function train!(ensemble::Ensemble; kwargs...)
Threads.@threads for m in eachindex(ensemble.models)
train!(ensemble.models[m]; kwargs...)
end
return ensemble
end

"""
StatsAPI.predict(ensemble::Ensemble, X; consensus = median, kwargs...)
Returns the prediction for the heterogeneous ensemble of models a dataset `X`.
The function used to aggregate the outputs from different models is `consensus`
(defaults to `median`). All other keyword arguments are passed to `predict`.
To get a direct estimate of the variability, the `consensus` function can be
changed to `iqr` (inter-quantile range), or any measure of variance.
"""
function StatsAPI.predict(ensemble::Ensemble, X; consensus = median, kwargs...)
= [predict(component, X; kwargs...) for component in ensemble.models]
= vec(mapslices(consensus, hcat(ŷ...); dims = 2))
return isone(length(ỹ)) ? only(ỹ) :
end

"""
StatsAPI.predict(ensemble::Ensemble; kwargs...)
Predicts the heterogeneous ensemble model for all training data.
"""
function StatsAPI.predict(ensemble::Ensemble; kwargs...)
return StatsAPI.predict(ensemble, features(first(ensemble.models)); kwargs...)
end
30 changes: 24 additions & 6 deletions SDeMo/src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ In addition, the SDM carries with it the training features and labels, as well
as a vector of indices indicating which variables are actually used by the
model.
"""
mutable struct SDM{F,L}
mutable struct SDM{F, L}
transformer::Transformer
classifier::Classifier
τ::Number # Threshold
Expand All @@ -34,14 +34,19 @@ mutable struct SDM{F,L}
v::AbstractVector # Variables
end

function SDM(::Type{TF}, ::Type{CF}, X::Matrix{T}, y::Vector{Bool}) where {TF <: Transformer, CF <: Classifier, T <: Number}
function SDM(
::Type{TF},
::Type{CF},
X::Matrix{T},
y::Vector{Bool},
) where {TF <: Transformer, CF <: Classifier, T <: Number}
return SDM(
TF(),
CF(),
zero(CF),
X,
y,
collect(1:size(X,1))
collect(1:size(X, 1)),
)
end

Expand All @@ -53,6 +58,13 @@ to be a presence.
"""
threshold(sdm::SDM) = sdm.τ

"""
threshold!(sdm::SDM, τ)
Sets the value of the threshold.
"""
threshold!(sdm::SDM, τ) = sdm.τ = τ

"""
features(sdm::SDM)
Expand All @@ -62,7 +74,6 @@ output of this function *will* change the content of the SDM features.
"""
features(sdm::SDM) = sdm.X


"""
features(sdm::SDM, n)
Expand All @@ -75,7 +86,7 @@ features(sdm::SDM, n) = sdm.X[n, :]
Returns the *n*-th instance stored in the field `X` of the SDM.
"""
function instance(sdm::SDM, n; strict=true)
function instance(sdm::SDM, n; strict = true)
if strict
return features(sdm)[variables(sdm), n]
else
Expand All @@ -98,4 +109,11 @@ Returns the list of variables used by the SDM -- these *may* be ordered by
importance. This does not return a copy of the variables array, but the array
itself.
"""
variables(sdm::SDM) = sdm.v
variables(sdm::SDM) = sdm.v

"""
variables!(sdm::SDM, v)
Sets the list of variables.
"""
variables!(sdm::SDM, v) = sdm.v = copy(v)
6 changes: 6 additions & 0 deletions SDeMo/src/utilities/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ function Base.show(io::IO, ensemble::Bagging)
return print(io, join(strs, "\n"))
end

function Base.show(io::IO, ensemble::Ensemble)
strs = ["\t $(m)" for m in ensemble.models]
pushfirst!(strs, "An ensemble model with:")
return print(io, join(strs, "\n"))
end

function Base.show(io::IO, sdm::SDM)
strs = [
"$(typeof(sdm.transformer))",
Expand Down

0 comments on commit 58a8200

Please sign in to comment.