diff --git a/src/tuned_models.jl b/src/tuned_models.jl
index 2eb0c97..a77361e 100644
--- a/src/tuned_models.jl
+++ b/src/tuned_models.jl
@@ -33,7 +33,7 @@ warn_double_spec(arg, model) =
 const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector}
 const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector}
 
-mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deterministic
+mutable struct DeterministicTunedModel{T,M<:DeterministicTypes,L} <: MLJBase.Deterministic
     model::M
     tuning::T  # tuning strategy
     resampling # resampling strategy
@@ -51,9 +51,10 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
     check_measure::Bool
     cache::Bool
     compact_history::Bool
+    logger::L
 end
 
-mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
+mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes,L} <: MLJBase.Probabilistic
     model::M
     tuning::T  # tuning strategy
     resampling # resampling strategy
@@ -71,10 +72,11 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
     check_measure::Bool
     cache::Bool
     compact_history::Bool
+    logger::L
 end
 
-const EitherTunedModel{T,M} =
-    Union{DeterministicTunedModel{T,M},ProbabilisticTunedModel{T,M}}
+const EitherTunedModel{T,M,L} =
+    Union{DeterministicTunedModel{T,M,L},ProbabilisticTunedModel{T,M,L}}
 
 MLJBase.caches_data_by_default(::Type{<:EitherTunedModel}) = false
 
@@ -279,6 +281,7 @@ function TunedModel(
     check_measure=true,
     cache=true,
     compact_history=true,
+    logger=nothing
 )
 
     # user can specify model as argument instead of kwarg:
@@ -342,6 +345,9 @@ function TunedModel(
     # get the tuning type parameter:
     T = typeof(tuning)
 
+    # get the logger type parameter:
+    L = typeof(logger)
+
     args = (
         model,
         tuning,
@@ -360,12 +366,13 @@ function TunedModel(
         check_measure,
         cache,
         compact_history,
+        logger
     )
 
     if M <: DeterministicTypes
-        tuned_model = DeterministicTunedModel{T,M}(args...)
+        tuned_model = DeterministicTunedModel{T,M,L}(args...)
     elseif M <: ProbabilisticTypes
-        tuned_model = ProbabilisticTunedModel{T,M}(args...)
+        tuned_model = ProbabilisticTunedModel{T,M,L}(args...)
     else
         throw(ERR_MODEL_TYPE)
     end
@@ -591,7 +598,7 @@ function assemble_events!(metamodels,
         end
     end
     # One resampling_machine per task
-    machs = [resampling_machine,
+    machs = [resampling_machine,
             [machine(Resampler(
                 model= resampling_machine.model.model,
                 resampling = resampling_machine.model.resampling,
@@ -603,9 +610,9 @@
                 repeats = resampling_machine.model.repeats,
                 acceleration = resampling_machine.model.acceleration,
                 cache = resampling_machine.model.cache,
-                compact = resampling_machine.model.compact
-            ), resampling_machine.args...; cache=false) for
-            _ in 2:length(partitions)]...]
+                compact = resampling_machine.model.compact,
+                logger = resampling_machine.model.logger),
+                resampling_machine.args...; cache=false) for _ in 2:length(partitions)]...]
 
     @sync for (i, parts) in enumerate(partitions)
         Threads.@spawn begin
@@ -740,8 +747,8 @@ function finalize(tuned_model,
     return fitresult, meta_state, report
 end
 
-function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
-                     verbosity::Integer, data...) where {T,M}
+function MLJBase.fit(tuned_model::EitherTunedModel{T,M,L},
+                     verbosity::Integer, data...) where {T,M,L}
     tuning = tuned_model.tuning
     model = tuned_model.model
     _range = tuned_model.range
@@ -769,6 +776,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
         acceleration = tuned_model.acceleration_resampling,
         cache = tuned_model.cache,
         compact = tuned_model.compact_history,
+        logger = tuned_model.logger
     )
     resampling_machine = machine(resampler, data...; cache=false)
     history, state = build!(nothing, n, tuning, model, model_buffer, state,
@@ -900,9 +908,9 @@ end
 ## METADATA
 
 MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
-MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
+MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
     MLJBase.supports_weights(M)
-MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
+MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
     MLJBase.supports_class_weights(M)
 MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) =
     "MLJTuning.ProbabilisticTunedModel"
@@ -914,9 +922,9 @@ MLJBase.package_uuid(::Type{<:EitherTunedModel}) =
 MLJBase.package_url(::Type{<:EitherTunedModel}) =
     "https://github.com/alan-turing-institute/MLJTuning.jl"
 MLJBase.package_license(::Type{<:EitherTunedModel}) = "MIT"
-MLJBase.is_pure_julia(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
+MLJBase.is_pure_julia(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
     MLJBase.is_pure_julia(M)
-MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
+MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
     MLJBase.input_scitype(M)
-MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
+MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
    MLJBase.target_scitype(M)
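Usage note (not part of the patch): a minimal sketch of the new `logger` keyword as it would be exercised from user code, assuming MLJ and DecisionTree.jl are installed. Because the field is unconstrained (`L = typeof(logger)`), any logger object accepted by the underlying `Resampler` can be passed; the commented-out `MLJFlow.Logger` line is illustrative only, and the `nothing` default reproduces the old, logger-free behaviour.

using MLJ  # re-exports TunedModel, Grid, CV, range, machine, fit!

Tree = @load DecisionTreeClassifier pkg=DecisionTree  # assumes DecisionTree.jl
tree = Tree()
X, y = @load_iris  # built-in demo dataset

tuned = TunedModel(
    model=tree,
    tuning=Grid(),
    resampling=CV(nfolds=3),
    range=range(tree, :max_depth, values=[2, 3, 4]),
    measure=log_loss,
    logger=nothing,  # new keyword introduced by this diff
    # logger=MLJFlow.Logger("http://localhost:5000"),  # hypothetical MLflow setup
)

mach = machine(tuned, X, y)
fit!(mach)  # per this diff, the logger is forwarded to every internal Resampler evaluation

Design note: storing the logger's type as a third struct parameter keeps `DeterministicTunedModel`/`ProbabilisticTunedModel` concretely typed, which is why the `where` clauses and trait definitions on `EitherTunedModel` all gain the extra `L` parameter.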