Skip to content

Commit

Permalink
adpat Resampler to adjusted evaluate signature
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Apr 24, 2024
1 parent b0818a3 commit 2bc3bec
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,7 @@ end
check_measure=true,
per_observation=true,
logger=nothing,
compact=false,
)
Resampling model wrapper, used internally by the `fit` method of `TunedModel` instances
Expand Down Expand Up @@ -1507,6 +1508,7 @@ mutable struct Resampler{S, L} <: Model
cache::Bool
per_observation::Bool
logger::L
compact::Bool
end

# Some traits are markded as `missing` because we cannot determine
Expand Down Expand Up @@ -1550,6 +1552,7 @@ function Resampler(
cache=true,
per_observation=true,
logger=nothing,
compact=false,
)
resampler = Resampler(
model,
Expand All @@ -1564,6 +1567,7 @@ function Resampler(
cache,
per_observation,
logger,
compact,
)
message = MLJModelInterface.clean!(resampler)
isempty(message) || @warn message
Expand Down Expand Up @@ -1612,6 +1616,7 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)
resampler.per_observation,
resampler.logger,
resampler.resampling,
resampler.compact,
)

fitresult = (machine = mach, evaluation = e)
Expand Down Expand Up @@ -1685,6 +1690,7 @@ function MLJModelInterface.update(
resampler.per_observation,
resampler.logger,
resampler.resampling,
resampler.compact,
)
report = (evaluation = e, )
fitresult = (machine=mach2, evaluation=e)
Expand Down
2 changes: 1 addition & 1 deletion test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ end
X, y = make_blobs(10)
e = evaluate(model, X, y)
ec = evaluate(model, X, y, compact=true)
@test e isa PeformanceEvaluation
@test e isa PerformanceEvaluation
@test ec isa CompactPerformanceEvaluation
@test startswith(sprint(show, MIME("text/plain"), e), "PerformanceEvaluation")
@test startswith(sprint(show, MIME("text/plain"), ec), "CompactPerformanceEvaluation")
Expand Down

0 comments on commit 2bc3bec

Please sign in to comment.