-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
205 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
indent = 4 | ||
margin = 92 | ||
always_for_in = true | ||
whitespace_typedefs = true | ||
whitespace_ops_in_indices = true | ||
remove_extra_newlines = true | ||
import_to_using = false | ||
pipe_to_function_call = false | ||
short_to_long_function_def = false | ||
long_to_short_function_def = false | ||
always_use_return = true | ||
whitespace_in_kwargs = true | ||
annotate_untyped_fields_with_any = true | ||
format_docstrings = true | ||
conditional_to_if = true | ||
normalize_line_endings = "unix" | ||
trailing_comma = true | ||
join_lines_based_on_source = true | ||
indent_submodule = true | ||
separate_kwargs_with_semicolon = true | ||
surround_whereop_typeparameters = true | ||
overwrite = true | ||
verbose = true | ||
format_markdown = true | ||
align_struct_field = true | ||
align_pair_arrow = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,9 @@ outofbag | |
bootstrap | ||
iqr | ||
``` | ||
|
||
## Heterogeneous ensembles | ||
|
||
```@docs | ||
Ensemble | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,9 @@ Classifier | |
features | ||
labels | ||
threshold | ||
threshold! | ||
variables | ||
variables! | ||
instance | ||
``` | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Ensemble | ||
An heterogeneous ensemble model is defined as a vector of `SDM`s. | ||
""" | ||
mutable struct Ensemble | ||
models::Vector{<:SDM} | ||
end | ||
|
||
Ensemble(m::T...) where {T <: SDM} = Ensemble([m...]) | ||
|
||
@testitem "We can setup an ensemble" begin | ||
X, y = SDeMo.__demodata() | ||
m1 = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y) | ||
m2 = SDM(ZScore, BIOCLIM, X, y) | ||
ens = Ensemble([m1, m2]) | ||
@test ens isa Ensemble | ||
end | ||
|
||
@testitem "We can setup an ensemble the other way" begin | ||
X, y = SDeMo.__demodata() | ||
m1 = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y) | ||
m2 = SDM(ZScore, BIOCLIM, X, y) | ||
ens = Ensemble(m1, m2) | ||
@test ens isa Ensemble | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
""" | ||
train!(ensemble::Bagging; kwargs...) | ||
Trains all the model in an ensemble model - the keyword arguments are passed to | ||
`train!` for each model. Note that this retrains the *entire* model, which | ||
includes the transformers. | ||
""" | ||
function train!(ensemble::Bagging; kwargs...) | ||
Threads.@threads for m in eachindex(ensemble.models) | ||
train!(ensemble.models[m]; training = ensemble.bags[m][1], kwargs...) | ||
end | ||
train!(ensemble.model; kwargs...) | ||
return ensemble | ||
end | ||
|
||
""" | ||
StatsAPI.predict(ensemble::Bagging, X; consensus = median, kwargs...) | ||
Returns the prediction for the ensemble of models a dataset `X`. The function | ||
used to aggregate the outputs from different models is `consensus` (defaults to | ||
`median`). All other keyword arguments are passed to `predict`. | ||
To get a direct estimate of the variability, the `consensus` function can be | ||
changed to `iqr` (inter-quantile range), or any measure of variance. | ||
""" | ||
function StatsAPI.predict(ensemble::Bagging, X; consensus = median, kwargs...) | ||
ŷ = [predict(component, X; kwargs...) for component in ensemble.models] | ||
ỹ = vec(mapslices(consensus, hcat(ŷ...); dims = 2)) | ||
return isone(length(ỹ)) ? only(ỹ) : ỹ | ||
end | ||
|
||
""" | ||
StatsAPI.predict(ensemble::Bagging; kwargs...) | ||
Predicts the ensemble model for all training data. | ||
""" | ||
function StatsAPI.predict(ensemble::Bagging; kwargs...) | ||
return StatsAPI.predict(ensemble, ensemble.model.X; kwargs...) | ||
end | ||
|
||
""" | ||
train!(ensemble::Ensemble; kwargs...) | ||
Trains all the model in an heterogeneous ensemble model - the keyword arguments | ||
are passed to `train!` for each model. Note that this retrains the *entire* | ||
model, which includes the transformers. | ||
The keywod arguments are passed to `train!` and can include the `training` | ||
indices. | ||
""" | ||
function train!(ensemble::Ensemble; kwargs...) | ||
Threads.@threads for m in eachindex(ensemble.models) | ||
train!(ensemble.models[m]; kwargs...) | ||
end | ||
return ensemble | ||
end | ||
|
||
""" | ||
StatsAPI.predict(ensemble::Ensemble, X; consensus = median, kwargs...) | ||
Returns the prediction for the heterogeneous ensemble of models a dataset `X`. | ||
The function used to aggregate the outputs from different models is `consensus` | ||
(defaults to `median`). All other keyword arguments are passed to `predict`. | ||
To get a direct estimate of the variability, the `consensus` function can be | ||
changed to `iqr` (inter-quantile range), or any measure of variance. | ||
""" | ||
function StatsAPI.predict(ensemble::Ensemble, X; consensus = median, kwargs...) | ||
ŷ = [predict(component, X; kwargs...) for component in ensemble.models] | ||
ỹ = vec(mapslices(consensus, hcat(ŷ...); dims = 2)) | ||
return isone(length(ỹ)) ? only(ỹ) : ỹ | ||
end | ||
|
||
""" | ||
StatsAPI.predict(ensemble::Ensemble; kwargs...) | ||
Predicts the heterogeneous ensemble model for all training data. | ||
""" | ||
function StatsAPI.predict(ensemble::Ensemble; kwargs...) | ||
return StatsAPI.predict(ensemble, features(first(ensemble.models)); kwargs...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters