From 31898d9d2290ad6c3e7998657015eaa3a87ab978 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 27 May 2024 18:51:03 +1200 Subject: [PATCH] add a default_logger --- src/init.jl | 1 + src/machines.jl | 18 +++++++++++++ src/resampling.jl | 69 +++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 83 insertions(+), 5 deletions(-) diff --git a/src/init.jl b/src/init.jl index 947bdc03..d85fede1 100644 --- a/src/init.jl +++ b/src/init.jl @@ -3,6 +3,7 @@ function __init__() global DEFAULT_RESOURCE = Ref{AbstractResource}(CPU1()) global DEFAULT_SCITYPE_CHECK_LEVEL = Ref{Int}(1) global SHOW_COLOR = Ref{Bool}(true) + global DEFAULT_LOGGER = Ref{Any}(nothing) # for testing asynchronous training of learning networks: global TESTING = parse(Bool, get(ENV, "TEST_MLJBASE", "false")) diff --git a/src/machines.jl b/src/machines.jl index 1a3f5388..c08fba9d 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -1088,6 +1088,24 @@ function save(file::Union{String,IO}, mach::Machine) serialize(file, smach) end +const ERR_INVALID_DEFAULT_LOGGER = ArgumentError( + "`default_logger()` is currently `nothing`. "* + "Either specify an explicit path or stream as "* + "target of the save, or use `default_logger(logger)` "* + "to change the default logger. " +) + +""" + MLJ.save(mach) + MLJBase.save(mach) + +Save the current machine as an artifact at the location associated with +`default_logger`](@ref). + +""" +MLJBase.save(mach::Machine) = MLJBase.save(default_logger(), mach) +MLJBase.save(::Nothing, ::Machine) = throw(ERR_INVALID_DEFAULT_LOGGER) + report_for_serialization(mach) = mach.report # NOTE. there is also a specialization of `report_for_serialization` for `Composite` diff --git a/src/resampling.jl b/src/resampling.jl index 250e3ca0..6549cb96 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -1,4 +1,4 @@ -# # TYPE ALIASES + # TYPE ALIASES const AbstractRow = Union{AbstractVector{<:Integer}, Colon} const TrainTestPair = Tuple{AbstractRow,AbstractRow} @@ -747,6 +747,64 @@ Base.show(io::IO, e::CompactPerformanceEvaluation) = print(io, "CompactPerformanceEvaluation$(_summary(e))") + +# =============================================================== +## USER CONTROL OF DEFAULT LOGGING + +const DOC_DEFAULT_LOGGER = + """ + + The default logger is used in calls to [`evaluate!`](@ref) and [`evaluate`](@ref), and + in the constructors `TunedModel` and `IteratedModel`, unless the `logger` keyword is + explicitly specified. + + !!! note + + In MLJ version prior to 0.21 the default logger is always `nothing`. + +""" + +""" + default_logger() + +Return the current value of the default logger for use with supported machine learning +tracking platforms, such as [MLflow](https://mlflow.org/docs/latest/index.html). + +$DOC_DEFAULT_LOGGER + + When MLJBase is first loaded, the default logger is `nothing`. To reset the logger, see + beow. + +""" +default_logger() = DEFAULT_LOGGER[] + +""" + default_logger(logger) + +Reset the default logger. + +# Example + +Suppose an [MLflow](https://mlflow.org/docs/latest/index.html) tracking service is running +on a local server at `http://127.0.0.1:500`. Then every in every `evaluate` call in which +`logger` is not specified, as in the example below, the peformance evaluation is +automatically logged to the service. + +```julia-repl +using MLJ +logger = MLJFlow.Logger("http://127.0.0.1:5000/api") +default_logger(logger) + +X, y = make_moons() +model = ConstantClassifier() +evaluate(model, X, y, measures=[log_loss, accuracy)]) + +""" +function default_logger(logger) + DEFAULT_LOGGER[] = logger +end + + # =============================================================== ## EVALUATION METHODS @@ -1068,7 +1126,8 @@ Although `evaluate!` is mutating, `mach.model` and `mach.args` are not mutated. `false` the `per_observation` field of the returned object is populated with `missing`s. Setting to `false` may reduce compute time and allocations. -- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref)) +- `logger=default_logger()` - a logger object for forwarding results to a machine learning + tracking platform; see [`default_logger`](@ref) for details. - `compact=false` - if `true`, the returned evaluation object excludes these fields: `fitted_params_per_fold`, `report_per_fold`, `train_test_rows`. @@ -1093,7 +1152,7 @@ function evaluate!( check_measure=true, per_observation=true, verbosity=1, - logger=nothing, + logger=default_logger(), compact=false, ) @@ -1544,7 +1603,7 @@ end acceleration=default_resource(), check_measure=true, per_observation=true, - logger=nothing, + logger=default_logger(), compact=false, ) @@ -1632,7 +1691,7 @@ function Resampler( repeats=1, cache=true, per_observation=true, - logger=nothing, + logger=default_logger(), compact=false, ) resampler = Resampler(