Skip to content

Commit

Permalink
Merge pull request #193 from pebeto/dev
Browse files Browse the repository at this point in the history
Adding loggers into TunedModels
  • Loading branch information
ablaom authored May 21, 2024
2 parents bb59cae + 2b63fa8 commit dc1d6d4
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ warn_double_spec(arg, model) =
# Model supertypes accepted by the tuned-model wrappers: the standard MLJ
# supertypes plus the corresponding detector variants from MLJModelInterface
# (as the qualified names below show). Used to constrain the `M` type
# parameter of the wrapper structs.
const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector}
const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector}

mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deterministic
mutable struct DeterministicTunedModel{T,M<:DeterministicTypes,L} <: MLJBase.Deterministic
model::M
tuning::T # tuning strategy
resampling # resampling strategy
Expand All @@ -51,9 +51,10 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
check_measure::Bool
cache::Bool
compact_history::Bool
logger::L
end

mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes,L} <: MLJBase.Probabilistic
model::M
tuning::T # tuning strategy
resampling # resampling strategy
Expand All @@ -71,10 +72,11 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
check_measure::Bool
cache::Bool
compact_history::Bool
logger::L
end

# Union of the two tuned-model wrapper types, parameterized by tuning
# strategy `T`, atomic model type `M`, and logger type `L` (added with
# logger support). NOTE(review): the diff residue also carried the stale
# two-parameter definition of this alias; only the three-parameter form is
# kept — a second, conflicting `const` definition would be a redefinition
# error in Julia.
const EitherTunedModel{T,M,L} =
    Union{DeterministicTunedModel{T,M,L},ProbabilisticTunedModel{T,M,L}}

MLJBase.caches_data_by_default(::Type{<:EitherTunedModel}) = false

Expand Down Expand Up @@ -279,6 +281,7 @@ function TunedModel(
check_measure=true,
cache=true,
compact_history=true,
logger=nothing
)

# user can specify model as argument instead of kwarg:
Expand Down Expand Up @@ -342,6 +345,9 @@ function TunedModel(
# get the tuning type parameter:
T = typeof(tuning)

# get the logger type parameter:
L = typeof(logger)

args = (
model,
tuning,
Expand All @@ -360,12 +366,13 @@ function TunedModel(
check_measure,
cache,
compact_history,
logger
)

if M <: DeterministicTypes
tuned_model = DeterministicTunedModel{T,M}(args...)
tuned_model = DeterministicTunedModel{T,M,L}(args...)
elseif M <: ProbabilisticTypes
tuned_model = ProbabilisticTunedModel{T,M}(args...)
tuned_model = ProbabilisticTunedModel{T,M,L}(args...)
else
throw(ERR_MODEL_TYPE)
end
Expand Down Expand Up @@ -591,7 +598,7 @@ function assemble_events!(metamodels,
end
end
# One resampling_machine per task
machs = [resampling_machine,
machs = [resampling_machine,
[machine(Resampler(
model= resampling_machine.model.model,
resampling = resampling_machine.model.resampling,
Expand All @@ -603,9 +610,9 @@ function assemble_events!(metamodels,
repeats = resampling_machine.model.repeats,
acceleration = resampling_machine.model.acceleration,
cache = resampling_machine.model.cache,
compact = resampling_machine.model.compact
), resampling_machine.args...; cache=false) for
_ in 2:length(partitions)]...]
compact = resampling_machine.model.compact,
logger = resampling_machine.model.logger),
resampling_machine.args...; cache=false) for _ in 2:length(partitions)]...]

@sync for (i, parts) in enumerate(partitions)
Threads.@spawn begin
Expand Down Expand Up @@ -740,8 +747,8 @@ function finalize(tuned_model,
return fitresult, meta_state, report
end

function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
verbosity::Integer, data...) where {T,M}
function MLJBase.fit(tuned_model::EitherTunedModel{T,M,L},
verbosity::Integer, data...) where {T,M,L}
tuning = tuned_model.tuning
model = tuned_model.model
_range = tuned_model.range
Expand Down Expand Up @@ -769,6 +776,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
acceleration = tuned_model.acceleration_resampling,
cache = tuned_model.cache,
compact = tuned_model.compact_history,
logger = tuned_model.logger
)
resampling_machine = machine(resampler, data...; cache=false)
history, state = build!(nothing, n, tuning, model, model_buffer, state,
Expand Down Expand Up @@ -900,9 +908,9 @@ end
## METADATA

MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
# Weight-support traits are delegated to the atomic model type `M`.
# NOTE(review): the diff residue interleaved the stale two-parameter
# signatures with the new three-parameter ones, leaving dangling `=`
# continuations that do not parse; only the updated `{<:Any,M,L}` methods
# are kept.
MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
    MLJBase.supports_weights(M)
MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
    MLJBase.supports_class_weights(M)
# Registry load path for the probabilistic variant of the wrapper.
function MLJBase.load_path(::Type{<:ProbabilisticTunedModel})
    return "MLJTuning.ProbabilisticTunedModel"
end
Expand All @@ -914,9 +922,9 @@ MLJBase.package_uuid(::Type{<:EitherTunedModel}) =
# Package metadata reported for every tuned-model wrapper.
function MLJBase.package_url(::Type{<:EitherTunedModel})
    return "https://github.com/alan-turing-institute/MLJTuning.jl"
end
function MLJBase.package_license(::Type{<:EitherTunedModel})
    return "MIT"
end
# The following traits are delegated to the atomic model type `M`.
# NOTE(review): the diff residue interleaved stale two-parameter signature
# lines with the updated three-parameter ones, producing unparseable
# dangling `=` continuations; only the updated `{T,M,L}` methods are kept.
MLJBase.is_pure_julia(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
    MLJBase.is_pure_julia(M)
MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
    MLJBase.input_scitype(M)
MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
    MLJBase.target_scitype(M)

0 comments on commit dc1d6d4

Please sign in to comment.