Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create option to write CompactPerformanceEvaluation objects to history #215

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
ComputationalResources = "0.3"
Distributions = "0.22,0.23,0.24, 0.25"
LatinHypercubeSampling = "1.7.2"
MLJBase = "1"
MLJBase = "1.3"
ProgressMeter = "1.7.1"
RecipesBase = "0.8,0.9,1"
StatisticalMeasuresBase = "0.1.1"
Expand Down
97 changes: 60 additions & 37 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
acceleration_resampling::AbstractResource
check_measure::Bool
cache::Bool
compact_history::Bool
end

mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
Expand All @@ -69,6 +70,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
acceleration_resampling::AbstractResource
check_measure::Bool
cache::Bool
compact_history::Bool
end

const EitherTunedModel{T,M} =
Expand Down Expand Up @@ -176,6 +178,15 @@ key | value

plus other key/value pairs specific to the `tuning` strategy.

Each element of `history` is a property-accessible object with these properties:

key | value
--------------------|--------------------------------------------------
`measure` | vector of measures (metrics)
`measurement` | vector of measurements, one per measure
`per_fold` | vector of vectors of unaggregated per-fold measurements
`evaluation` | full `PerformanceEvaluation`/`CompactPerformaceEvaluation` object

### Complete list of key-word options

- `model`: `Supervised` model prototype that is cloned and mutated to
Expand Down Expand Up @@ -240,27 +251,35 @@ plus other key/value pairs specific to the `tuning` strategy.
user-suplied data; set to `false` to conserve memory. Speed gains
likely limited to the case `resampling isa Holdout`.

- `compact_history=true`: whether to write `CompactPerformanceEvaluation`](@ref) or
regular [`PerformanceEvaluation`](@ref) objects to the history (accessed via the
`:evaluation` key); the compact form excludes some fields to conserve memory.

"""
function TunedModel(args...; model=nothing,
models=nothing,
tuning=nothing,
resampling=MLJBase.Holdout(),
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operations=nothing,
operation=operations,
ranges=nothing,
range=ranges,
selection_heuristic=NaiveSelection(),
train_best=true,
repeats=1,
n=nothing,
acceleration=default_resource(),
acceleration_resampling=CPU1(),
check_measure=true,
cache=true)
function TunedModel(
args...;
model=nothing,
models=nothing,
tuning=nothing,
resampling=MLJBase.Holdout(),
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operations=nothing,
operation=operations,
ranges=nothing,
range=ranges,
selection_heuristic=NaiveSelection(),
train_best=true,
repeats=1,
n=nothing,
acceleration=default_resource(),
acceleration_resampling=CPU1(),
check_measure=true,
cache=true,
compact_history=true,
)

# user can specify model as argument instead of kwarg:
length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
Expand Down Expand Up @@ -339,7 +358,8 @@ function TunedModel(args...; model=nothing,
acceleration,
acceleration_resampling,
check_measure,
cache
cache,
compact_history,
)

if M <: DeterministicTypes
Expand Down Expand Up @@ -582,9 +602,10 @@ function assemble_events!(metamodels,
check_measure = resampling_machine.model.check_measure,
repeats = resampling_machine.model.repeats,
acceleration = resampling_machine.model.acceleration,
cache = resampling_machine.model.cache),
resampling_machine.args...; cache=false) for
_ in 2:length(partitions)]...]
cache = resampling_machine.model.cache,
compact = resampling_machine.model.compact
), resampling_machine.args...; cache=false) for
_ in 2:length(partitions)]...]

@sync for (i, parts) in enumerate(partitions)
Threads.@spawn begin
Expand Down Expand Up @@ -736,21 +757,23 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},

# instantiate resampler (`model` to be replaced with mutated
# clones during iteration below):
resampler = Resampler(model=model,
resampling = deepcopy(tuned_model.resampling),
measure = tuned_model.measure,
weights = tuned_model.weights,
class_weights = tuned_model.class_weights,
operation = tuned_model.operation,
check_measure = tuned_model.check_measure,
repeats = tuned_model.repeats,
acceleration = tuned_model.acceleration_resampling,
cache = tuned_model.cache)
resampler = Resampler(
model=model,
resampling = deepcopy(tuned_model.resampling),
measure = tuned_model.measure,
weights = tuned_model.weights,
class_weights = tuned_model.class_weights,
operation = tuned_model.operation,
check_measure = tuned_model.check_measure,
repeats = tuned_model.repeats,
acceleration = tuned_model.acceleration_resampling,
cache = tuned_model.cache,
compact = tuned_model.compact_history,
)
resampling_machine = machine(resampler, data...; cache=false)
history, state = build!(nothing, n, tuning, model, model_buffer, state,
verbosity, acceleration, resampling_machine)


return finalize(
tuned_model,
model_buffer,
Expand Down Expand Up @@ -867,9 +890,9 @@ function MLJBase.reports_feature_importances(model::EitherTunedModel)
end # This is needed in some cases (e.g tuning a `Pipeline`)

function MLJBase.feature_importances(::EitherTunedModel, fitresult, report)
# fitresult here is a machine created using the best_model obtained
# fitresult here is a machine created using the best_model obtained
# from the tuning process.
# The line below will return `nothing` when the model being tuned doesn't
# The line below will return `nothing` when the model being tuned doesn't
# support feature_importances.
return MLJBase.feature_importances(fitresult)
end
Expand Down
36 changes: 25 additions & 11 deletions test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using Random
Random.seed!(1234*myid())
using .TestUtilities

begin
begin
N = 30
x1 = rand(N);
x2 = rand(N);
Expand Down Expand Up @@ -157,14 +157,14 @@ end

@testset_accelerated "Feature Importances" accel begin
# the DecisionTreeClassifier in /test/_models/ supports feature importances.
tm0 = TunedModel(
model = trees[1],
measure = rms,
tuning = Grid(),
resampling = CV(nfolds = 5),
range = range(
trees[1], :max_depth, values = 1:10
)
tm0 = TunedModel(
model = trees[1],
measure = rms,
tuning = Grid(),
resampling = CV(nfolds = 5),
range = range(
trees[1], :max_depth, values = 1:10
)
)
@test reports_feature_importances(typeof(tm0))
tm = TunedModel(
Expand Down Expand Up @@ -435,7 +435,7 @@ end
model = DecisionTreeClassifier()
tmodel = TunedModel(models=[model,])
mach = machine(tmodel, X, y)
@test mach isa Machine{<:Any,false}
@test !MLJBase.caches_data(mach)
fit!(mach, verbosity=-1)
@test !isdefined(mach, :data)
MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
Expand Down Expand Up @@ -490,7 +490,7 @@ end
@test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
end

@testset_accelerated "full evaluation object" accel begin
@testset_accelerated "evaluation object" accel begin
X, y = make_regression(100, 2)
dcr = DeterministicConstantRegressor()

Expand All @@ -504,10 +504,24 @@ end
fit!(homach, verbosity=0);
horep = report(homach)
evaluations = getproperty.(horep.history, :evaluation)
@test first(evaluations) isa MLJBase.CompactPerformanceEvaluation
measurements = getproperty.(evaluations, :measurement)
models = getproperty.(evaluations, :model)
@test all(==(measurements[1]), measurements)
@test all(==(dcr), models)

homodel = TunedModel(
models=fill(dcr, 10),
resampling=Holdout(rng=StableRNG(1234)),
acceleration_resampling=accel,
measure=mae,
compact_history=false,
)
homach = machine(homodel, X, y)
fit!(homach, verbosity=0);
horep = report(homach)
evaluations = getproperty.(horep.history, :evaluation)
@test first(evaluations) isa MLJBase.PerformanceEvaluation
end

true
Loading