From c0c7f8ab2768b720de897079ffd0f3c17f26c35d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 24 Apr 2024 17:10:41 +1200 Subject: [PATCH 1/2] add InSample resampling strategy to close #967 --- src/MLJBase.jl | 2 +- src/resampling.jl | 48 ++++++++++++++++++++++++++++++++++++++++++++-- test/resampling.jl | 16 ++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/MLJBase.jl b/src/MLJBase.jl index bde58cc4..f1ebe249 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -291,7 +291,7 @@ export machines, sources, Stack, export TransformedTargetModel # resampling.jl: -export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV, +export ResamplingStrategy, InSample, Holdout, CV, StratifiedCV, TimeSeriesCV, evaluate!, Resampler, PerformanceEvaluation # `MLJType` and the abstract `Model` subtypes are exported from within diff --git a/src/resampling.jl b/src/resampling.jl index 3759e136..382f915f 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -110,6 +110,50 @@ function shuffle_and_rng(shuffle, rng) return shuffle, rng end +# ---------------------------------------------------------------- +# InSample + +""" + in_sample = InSample() + +Instantiate an `InSample` resampling strategy, for use in `evaluate!`, `evaluate` and in +tuning. In this strategy the train and test sets are the same, and consist of all +observations specified by the `rows` keyword argument. If `rows` is not specified, all +supplied rows are used. + +# Example + +```julia +using MLJBase, MLJModels + +X, y = make_blobs() # a table and a vector +model = ConstantClassifier() +train, test = partition(eachindex(y), 0.7) # train:test = 70:30 +``` + +Compute in-sample (training) loss: + +```julia +evaluate(model, X, y, resampling=InSample(), rows=train, measure=brier_loss) +``` + +Compute the out-of-sample loss: + +```julia +evaluate(model, X, y, resampling=[(train, test),], measure=brier_loss) +``` + +Or equivalently: + +```julia +evaluate(model, X, y, resampling=Holdout(fraction_train=0.7), measure=brier_loss) +``` + +""" +struct InSample <: ResamplingStrategy end + +train_test_pairs(::InSample, rows) = [(rows, rows),] + # ---------------------------------------------------------------- # Holdout @@ -118,7 +162,7 @@ end shuffle=nothing, rng=nothing) -Holdout resampling strategy, for use in `evaluate!`, `evaluate` and in +Instantiate a `Holdout` resampling strategy, for use in `evaluate!`, `evaluate` and in tuning. train_test_pairs(holdout, rows) @@ -345,7 +389,7 @@ end rng=Random.GLOBAL_RNG) Stratified cross-validation resampling strategy, for use in -`evaluate!`, `evaluate` and in tuning. Applies only to classification +`evaluate!`, `evaluate` and intuning. Applies only to classification problems (`OrderedFactor` or `Multiclass` targets). train_test_pairs(stratified_cv, rows, y) diff --git a/test/resampling.jl b/test/resampling.jl index d27af319..288cd967 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -364,6 +364,22 @@ end end end +@testset "insample" begin + rows = rand(Int, 100) + @test MLJBase.train_test_pairs(InSample(), rows) == [(rows, rows),] + + X, y = make_regression(20) + model = Models.DeterministicConstantRegressor() + + # all rows: + e = evaluate(model, X, y, resampling=InSample(), measure=rms) + @test e.measurement[1] ≈ std(y, corrected=false) + + # subsample of rows: + e = evaluate(model, X, y, resampling=InSample(), measure=rms, rows=1:7) + @test e.measurement[1] ≈ std(y[1:7], corrected=false) +end + @testset_accelerated "holdout" accel begin x1 = ones(4) x2 = ones(4) From 2c85c301010f0d4f14c8ee4b6681a69cd39bec72 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 6 May 2024 12:21:44 +1200 Subject: [PATCH 2/2] typo identified in review --- src/resampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/resampling.jl b/src/resampling.jl index af543582..8fc2c948 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -389,7 +389,7 @@ end rng=Random.GLOBAL_RNG) Stratified cross-validation resampling strategy, for use in -`evaluate!`, `evaluate` and intuning. Applies only to classification +`evaluate!`, `evaluate` and in tuning. Applies only to classification problems (`OrderedFactor` or `Multiclass` targets). train_test_pairs(stratified_cv, rows, y)