Skip to content

Commit

Permalink
doc(demo): fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoisot committed Oct 15, 2024
1 parent 64fa461 commit bff9518
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 33 deletions.
62 changes: 31 additions & 31 deletions SDeMo/docs/src/demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ size(X)

# ## Setting up the model

## We will start with an initial model that uses a PCA to transform the data, and
# We will start with an initial model that uses a PCA to transform the data, and
# then a Naive Bayes Classifier for the classification. Note that this is the
# partial syntax where we use the default threshold, and all the variables:

Expand All @@ -53,11 +53,11 @@ cv = crossvalidate(sdm, folds);
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [mean(measure.(set)) for measure in measures, set in cv]
pretty_table(
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
)

# Assuming we want to get a simple idea of what the MCC is for the validation
Expand All @@ -69,6 +69,14 @@ mcc.(cv.validation)

ci(cv.validation, mcc)

# We can also get the same output by calling a function on a vector of `ConfusionMatrix`, *e.g.*

mcc(cv.validation)

# Adding the `true` argument returns a tuple with the 95% CI:

mcc(cv.validation, true)

# ## Variable selection

# We will now select variables using forward selection, but with the added
Expand All @@ -87,24 +95,16 @@ cv2 = crossvalidate(sdm, folds)
measures = [mcc, balancedaccuracy, ppv, npv, trueskill, markedness]
cvresult = [mean(measure.(set)) for measure in measures, set in cv2]
pretty_table(
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
hcat(string.(measures), cvresult);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Measure", "Validation", "Training"],
formatters=ft_printf("%5.3f", [2, 3])
)

# Quite clearly! Before thinking about the relative importance of variables, we
# will take a look at the thresold.

# As a side-note, we can get the average of many performance measures using *e.g.*

f1(cv2.validation)

# And the mean and 95% confidence interval with

balancedaccuracy(cv2.validation, true)

# ## Moving threshold classification

# The `crossvalidate` function comes with an optional argument to specify the
Expand Down Expand Up @@ -144,11 +144,11 @@ varimp = variableimportance(sdm, folds)
# In relative terms, this is:

pretty_table(
hcat(variables(sdm), varimp ./ sum(varimp));
alignment=[:l, :c],
backend=Val(:markdown),
header=["Variable", "Importance"],
formatters=(ft_printf("%5.3f", 2), ft_printf("%d", 1))
hcat(variables(sdm), varimp ./ sum(varimp));
alignment=[:l, :c],
backend=Val(:markdown),
header=["Variable", "Importance"],
formatters=(ft_printf("%5.3f", 2), ft_printf("%d", 1))
)

# ## Partial response curve
Expand Down Expand Up @@ -192,8 +192,8 @@ f = Figure()
ax = Axis(f[1, 1])
prx, pry = partialresponse(sdm, 1; inflated=false, threshold=false)
for i in 1:200
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
ix, iy = partialresponse(sdm, 1; inflated=true, threshold=false)
lines!(ax, ix, iy, color=(:grey, 0.5))
end
lines!(ax, prx, pry, color=:black, linewidth=4)
current_figure() #hide
Expand Down Expand Up @@ -287,11 +287,11 @@ cf = counterfactual(sdm, instance(sdm, inst; strict=false), target, 200.0; thres
# is:

pretty_table(
hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Variable", "Obs.", "Counterf."],
formatters=(ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1))
hcat(variables(sdm), instance(sdm, inst), cf[variables(sdm)]);
alignment=[:l, :c, :c],
backend=Val(:markdown),
header=["Variable", "Obs.", "Counterf."],
formatters=(ft_printf("%4.1f", [2, 3]), ft_printf("%d", 1))
)

# We can check the prediction that would be made on the counterfactual:
Expand Down
4 changes: 2 additions & 2 deletions SDeMo/src/crossvalidation/crossvalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ for op in (:leaveoneout, :holdout, :montecarlo, :kfold)
$op(sdm::SDM, args...; kwargs...) =
$op(labels(sdm), features(sdm), args...; kwargs...)
"""
$($op)(sdm::Bagging)
$($op)(sdm::Bagging)
Version of `$($op)` using the instances and labels of a bagged SDM.
Version of `$($op)` using the instances and labels of a bagged SDM. In this case, the instances of the model used as a reference to build the bagged model are used.
"""
$op(sdm::Bagging, args...; kwargs...) = $op(sdm.model, args...; kwargs...)
end,
Expand Down

0 comments on commit bff9518

Please sign in to comment.