Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDeMo bugfix and QOL #306

Merged
merged 6 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion SDeMo/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SDeMo"
uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467"
authors = ["Timothée Poisot <[email protected]>"]
version = "0.0.5"
version = "0.0.6"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
54 changes: 31 additions & 23 deletions SDeMo/docs/src/demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ size(X)

# ## Setting up the model

## We will start with an initial model that uses a PCA to transform the data, and
# We will start with an initial model that uses a PCA to transform the data, and
# then a Naive Bayes Classifier for the classification. Note that this is the
# partial syntax where we use the default threshold, and all the variables:

Expand All @@ -53,11 +53,11 @@ cv = crossvalidate(sdm, folds);
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [mean(measure.(set)) for measure in measures, set in cv]
pretty_table(
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
)

# Assuming we want to get a simple idea of what the MCC is for the validation
Expand All @@ -69,6 +69,14 @@ mcc.(cv.validation)

ci(cv.validation, mcc)

# We can also get the same output by calling a function on a vector of `ConfusionMatrix`, *e.g.*

mcc(cv.validation)

# Adding the `true` argument returns a tuple with the 95% CI:

mcc(cv.validation, true)

# ## Variable selection

# We will now select variables using forward selection, but with the added
Expand All @@ -87,11 +95,11 @@ cv2 = crossvalidate(sdm, folds)
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [mean(measure.(set)) for measure in measures, set in cv2]
pretty_table(
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
)

# Quite clearly! Before thinking about the relative importance of variables, we
Expand Down Expand Up @@ -136,11 +144,11 @@ varimp = variableimportance(sdm, folds)
# In relative terms, this is:

pretty_table(
hcat(variables(sdm), varimp ./ sum(varimp));
alignment=[:l, :c],
backend=Val(:markdown),
header=["Variable", "Importance"],
formatters=(ft_printf("%5.3f", 2), ft_printf("%d", 1))
hcat(variables(sdm), varimp ./ sum(varimp));
alignment=[:l, :c],
backend=Val(:markdown),
header=["Variable", "Importance"],
formatters=(ft_printf("%5.3f", 2), ft_printf("%d", 1))
)

# ## Partial response curve
Expand Down Expand Up @@ -184,8 +192,8 @@ f = Figure()
ax = Axis(f[1, 1])
prx, pry = partialresponse(sdm, 1; inflated=false, threshold=false)
for i in 1:200
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
end
lines!(ax, prx, pry, color=:black, linewidth=4)
current_figure() #hide
Expand Down Expand Up @@ -279,11 +287,11 @@ cf = counterfactual(sdm, instance(sdm, inst; strict=false), target, 200.0; thres
# is:

pretty_table(
hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Variable", "Obs.", "Counterf."],
formatters=(ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1))
hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Variable", "Obs.", "Counterf."],
formatters=(ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1))
)

# We can check the prediction that would be made on the counterfactual:
Expand Down
59 changes: 38 additions & 21 deletions SDeMo/src/crossvalidation/crossvalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,15 @@ function montecarlo(y, X; n = 100, kwargs...)
return [holdout(y, X; kwargs...) for _ in 1:n]
end


@testitem "We can do montecarlo validation" begin
X, y = SDeMo.__demodata()
model = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
folds = montecarlo(model; n=10)
folds = montecarlo(model; n = 10)
cv = crossvalidate(model, folds)
@test eltype(cv.validation) <: ConfusionMatrix
@test length(cv.training) == 10
end


"""
kfold(y, X; k = 10, permute = true)

Expand Down Expand Up @@ -113,32 +111,39 @@ function kfold(y, X; k = 10, permute = true)
return folds
end


@testitem "We can do kfold validation" begin
X, y = SDeMo.__demodata()
model = SDM(MultivariateTransform{PCA}, NaiveBayes, X, y)
folds = kfold(model; k=12)
folds = kfold(model; k = 12)
cv = crossvalidate(model, folds)
@test eltype(cv.validation) <: ConfusionMatrix
@test length(cv.training) == 12
end


for op in (:leaveoneout, :holdout, :montecarlo, :kfold)
eval(quote
"""
$($op)(sdm::SDM)

Version of `$($op)` using the instances and labels of an SDM.
"""
$op(sdm::SDM, args...; kwargs...) = $op(labels(sdm), features(sdm), args...; kwargs...)
end)
eval(
quote
"""
$($op)(sdm::SDM)

Version of `$($op)` using the instances and labels of an SDM.
"""
$op(sdm::SDM, args...; kwargs...) =
$op(labels(sdm), features(sdm), args...; kwargs...)
"""
$($op)(sdm::Bagging)

Version of `$($op)` using the instances and labels of a bagged SDM. In this case, the instances of the model used as a reference to build the bagged model are used.
"""
$op(sdm::Bagging, args...; kwargs...) = $op(sdm.model, args...; kwargs...)
end,
)
end

@testitem "We can split data in an SDM" begin
X, y = SDeMo.__demodata()
sdm = SDM(MultivariateTransform{PCA}(), BIOCLIM(), 0.01, X, y, 1:size(X, 1))
folds = montecarlo(sdm; n=10)
folds = montecarlo(sdm; n = 10)
@test length(folds) == 10
end

Expand All @@ -153,27 +158,39 @@ This method returns two vectors of `ConfusionMatrix`, with the confusion matrix
for each set of validation data first, and the confusion matrix for the training
data second.
"""
function crossvalidate(sdm, folds; thr = nothing, kwargs...)
function crossvalidate(sdm::T, folds; thr = nothing, kwargs...) where {T <: AbstractSDM}
Cv = zeros(ConfusionMatrix, length(folds))
Ct = zeros(ConfusionMatrix, length(folds))
models = [deepcopy(sdm) for _ in Base.OneTo(Threads.nthreads())]
Threads.@threads for i in eachindex(folds)
trn, val = folds[i]
train!(models[Threads.threadid()]; training = trn, kwargs...)
pred = predict(models[Threads.threadid()], features(sdm)[:, val]; threshold = false)
ontrn = predict(models[Threads.threadid()], features(sdm)[:, trn]; threshold = false)
ontrn =
predict(models[Threads.threadid()], features(sdm)[:, trn]; threshold = false)
thr = isnothing(thr) ? threshold(sdm) : thr
Cv[i] = ConfusionMatrix(pred, labels(sdm)[val], thr)
Ct[i] = ConfusionMatrix(ontrn, labels(sdm)[trn], thr)
end
return (validation = Cv, training = Ct)
end

@testitem "We can crossvalidate an SDM" begin
@testitem "We can cross-validate an SDM" begin
X, y = SDeMo.__demodata()
sdm = SDM(MultivariateTransform{PCA}(), BIOCLIM(), 0.5, X, y, [1,2,12])
sdm = SDM(MultivariateTransform{PCA}(), BIOCLIM(), 0.5, X, y, [1, 2, 12])
train!(sdm)
cv = crossvalidate(sdm, kfold(sdm; k=15))
cv = crossvalidate(sdm, kfold(sdm; k = 15))
@test eltype(cv.validation) <: ConfusionMatrix
@test eltype(cv.training) <: ConfusionMatrix
end

# Smoke test: `crossvalidate` on a bagged ensemble must accept and forward the
# `consensus` keyword (the function used to aggregate per-model predictions)
# without erroring, and still return confusion matrices for both data splits.
@testitem "We can cross-validate an ensemble model using the consensus keyword" begin
    using Statistics
    X, y = SDeMo.__demodata()
    # Single SDM used as the template model for the bagged ensemble
    sdm = SDM(MultivariateTransform{PCA}(), NaiveBayes(), 0.5, X, y, [1, 2, 12])
    ens = Bagging(sdm, 10)
    train!(ens)
    # `consensus = median` aggregates the predictions across the 10 bags
    cv = crossvalidate(ens, kfold(ens; k = 15); consensus = median)
    @test eltype(cv.validation) <: ConfusionMatrix
    @test eltype(cv.training) <: ConfusionMatrix
end
end
51 changes: 50 additions & 1 deletion SDeMo/src/crossvalidation/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,53 @@ specificity(M::ConfusionMatrix) = tnr(M)

Alias for `ppv`, the positive predictive value
"""
precision(M::ConfusionMatrix) = ppv(M)
precision(M::ConfusionMatrix) = ppv(M)

# Programmatically extend every scalar performance measure (true positive rate,
# MCC, precision, etc.) so it also accepts a vector of `ConfusionMatrix` — e.g.
# the `validation` or `training` output of `crossvalidate`. Each generated
# method returns the mean of the measure over the vector, and optionally its
# confidence interval as computed by `ci`.
for op in (
    :tpr,
    :tnr,
    :fpr,
    :fnr,
    :ppv,
    :npv,
    :fdir,
    :fomr,
    :plr,
    :nlr,
    :accuracy,
    :balancedaccuracy,
    :f1,
    :fscore,
    :trueskill,
    :markedness,
    :dor,
    :κ,
    :mcc,
    :specificity,
    :sensitivity,
    :recall,
    :precision,
)
    # `$op` interpolates the measure's symbol into the quoted expression, so
    # each iteration defines (and documents) one new vector-accepting method.
    eval(
        quote
            """
                $($op)(C::Vector{ConfusionMatrix}, full::Bool=false)

            Version of `$($op)` using a vector of confusion matrices. Returns the mean, and when the second argument is `true`, returns a tuple where the second element is the CI.
            """
            function $op(
                C::Vector{ConfusionMatrix},
                full::Bool = false,
                args...;
                kwargs...,
            )
                # Broadcast the scalar measure over every confusion matrix;
                # extra positional/keyword arguments are forwarded unchanged.
                m = $op.(C, args...; kwargs...)
                if full
                    # Mean of the measure together with its confidence interval
                    return (mean(m), ci(C, $op))
                else
                    return mean(m)
                end
            end
        end,
    )
end
7 changes: 5 additions & 2 deletions SDeMo/src/ensembles/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ Trains all the model in an ensemble model - the keyword arguments are passed to
includes the transformers.
"""
function train!(ensemble::Bagging; kwargs...)
# The ensemble model can be given a consensus argument, in which case we drop it for
# training as it's relevant for prediction only
trainargs = filter(kw -> kw.first != :consensus, kwargs)
Threads.@threads for m in eachindex(ensemble.models)
train!(ensemble.models[m]; training = ensemble.bags[m][1], kwargs...)
train!(ensemble.models[m]; training = ensemble.bags[m][1], trainargs...)
end
train!(ensemble.model; kwargs...)
train!(ensemble.model; trainargs...)
return ensemble
end

Expand Down
Loading