
Merge pull request #6 from alan-turing-institute/dev
More readme updates
ablaom authored Feb 3, 2020
2 parents 7635c2e + cad039c commit f4e0ac0
Showing 6 changed files with 82 additions and 50 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -7,6 +7,12 @@ learning models.
[![Build Status](https://travis-ci.com/alan-turing-institute/MLJTuning.jl.svg?branch=master)](https://travis-ci.com/alan-turing-institute/MLJTuning.jl)
[![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/MLJTuning.jl/badge.svg?branch=master)](https://coveralls.io/github/alan-turing-institute/MLJTuning.jl?branch=master)

### Contents

- [Who is this repo for?](#who-is-this-repo-for)
- [What's provided here?](#whats-provided-here)
- [How do I implement a new tuning strategy?](#how-do-i-implement-a-new-tuning-strategy)

*Note:* This component of the [MLJ
stack](https://github.com/alan-turing-institute/MLJ.jl#the-mlj-universe)
applies to MLJ versions 0.8.0 and higher. Prior to 0.8.0, tuning
@@ -84,7 +90,7 @@ This repository contains:
these are essentially one-dimensional grid searches


## Implementing a New Tuning Strategy
## How do I implement a new tuning strategy?

This document assumes familiarity with the [Evaluating Model
Performance](https://alan-turing-institute.github.io/MLJ.jl/dev/evaluating_model_performance/)
@@ -130,7 +136,7 @@ begin, on the basis of the specific strategy and a user-specified
- An *evaluation* is the value returned by some call to the
`evaluate!` method, when passed the resampling strategy (e.g.,
`CV(nfolds=9)`) and performance measures specified by the user when
specifying the tuning task (e.g., `cross_entropy`,
`accuracy`). Recall that such a value is a named tuple of vectors
with keys `measure`, `measurement`, `per_fold`, and
`per_observation`. See [Evaluating Model
@@ -141,7 +147,7 @@ begin, on the basis of the specific strategy and a user-specified
value (the `per_observation` entries being recorded as
`missing`). This and other behavior can be inspected using trait
functions. Do `info(rms)` to view the trait values for the `rms`
loss, and see [Performance
loss, for example, and see [Performance
measures](https://alan-turing-institute.github.io/MLJ.jl/dev/performance_measures/)
for details.
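
By way of orientation, here is a minimal sketch of inspecting such an evaluation and the traits of a measure; the dataset and model are illustrative only, and the call signatures assume the MLJ documentation linked above:

```julia
using MLJ

X, y = @load_boston                      # illustrative regression data
model = @load DecisionTreeRegressor      # any Deterministic regressor would do

e = evaluate(model, X, y,
             resampling=CV(nfolds=9),
             measures=[rms, mae])

e.measure          # the measures applied
e.measurement      # one aggregated measurement per measure
e.per_fold         # vector of per-fold measurement vectors
e.per_observation  # `missing` for measures, like `rms`, reporting aggregates only

info(rms)          # view the trait values of the `rms` loss
```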

@@ -157,7 +163,7 @@ begin, on the basis of the specific strategy and a user-specified
- A *tuning strategy* is an instance of some subtype `S <:
TuningStrategy`, the name `S` (e.g., `Grid`) indicating the tuning
algorithm to be applied. The fields of the tuning strategy - called
*hyperparameters* - are those tuning parameters specific to the
*tuning hyperparameters* - are those tuning parameters specific to the
strategy that **do not refer to specific models or specific model
hyperparameters**. So, for example, a default resolution to be used
in a grid search is a hyperparameter of `Grid`, but the resolution
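
To make the distinction concrete, here is a hedged sketch of a strategy type following this pattern; the name `MonteCarlo` and its fields are hypothetical, not part of MLJTuning:

```julia
import MLJTuning
using Random

# Hypothetical strategy: its fields configure the search itself and make no
# reference to any particular model or model hyperparameter:
mutable struct MonteCarlo <: MLJTuning.TuningStrategy
    n_candidates::Int        # how many random candidates to propose
    rng::Random.AbstractRNG  # randomness driving the strategy
end

MonteCarlo(; n_candidates=10, rng=Random.GLOBAL_RNG) =
    MonteCarlo(n_candidates, rng)
```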
4 changes: 2 additions & 2 deletions src/MLJTuning.jl
@@ -4,10 +4,10 @@ module MLJTuning
## METHOD EXPORT

# defined in tuned_models.jl:
export Grid, TunedModel
export TunedModel

# defined in strategies/:
export Explicit
export Explicit, Grid

# defined in learning_curves.jl:
export learning_curve!, learning_curve
11 changes: 6 additions & 5 deletions src/learning_curves.jl
@@ -43,11 +43,12 @@ plot(curve.parameter_values,
ylab = "CV estimate of RMS error")
```
If using a `Holdout` `resampling` strategy, and the specified
hyperparameter is the number of iterations in some iterative model
(and that model has an appropriately overloaded `MLJBase.update`
method) then training is not restarted from scratch for each increment
of the parameter, i.e., the model is trained progressively.
If using a `Holdout()` `resampling` strategy (with no shuffling) and
if the specified hyperparameter is the number of iterations in some
iterative model (and that model has an appropriately overloaded
`MLJBase.update` method) then training is not restarted from scratch
for each increment of the parameter, i.e., the model is trained
progressively.
```julia
atom.lambda=200
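
Below is a hedged sketch of the progressive-training case described in this docstring; the base model, data, and `EnsembleModel` keyword names are illustrative assumptions:

```julia
using MLJ

X, y = @load_boston                                 # illustrative data
atom = @load RidgeRegressor pkg=MultivariateStats   # illustrative base model
ensemble = EnsembleModel(atom=atom, n=1000)
mach = machine(ensemble, X, y)

# With a no-shuffle Holdout and a range over the iteration parameter `n`,
# each increment of `n` trains only the additional ensemble members:
r = range(ensemble, :n, lower=10, upper=1000)
curve = learning_curve!(mach,
                        range=r,
                        resampling=Holdout(fraction_train=0.7),  # no shuffling
                        measure=rms)
```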
10 changes: 7 additions & 3 deletions src/strategies/grid.jl
@@ -41,8 +41,10 @@ cases all `values` of each specified `NominalRange` are exhausted. If
resolution is applied to the `NumericRange` objects that maximizes the
number of grid points, subject to the restriction that this not exceed
`goal`. Otherwise the default `resolution` and any parameter-specific
resolutions apply. In all cases the models generated are shuffled
using `rng`, unless `shuffle=false`.
resolutions apply.
In all cases the models generated are shuffled using `rng`, unless
`shuffle=false`.
See also [TunedModel](@ref), [range](@ref).
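
As a hedged illustration of the `goal` logic just described (the model and ranges are placeholders):

```julia
using MLJ

tree = @load DecisionTreeClassifier                  # illustrative model
r1 = range(tree, :max_depth, lower=2, upper=100)     # a NumericRange
r2 = range(tree, :post_prune, values=[true, false])  # a NominalRange

# Both values of r2 are always exhausted (2 grid points); a common resolution
# for r1 is then chosen as large as possible, subject to 2*resolution <= goal:
tuning = Grid(goal=20)
self_tuning = TunedModel(model=tree, tuning=tuning, range=[r1, r2],
                         resampling=CV(nfolds=6), measure=cross_entropy)
```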
@@ -122,6 +124,8 @@ function default_n(tuning::Grid, user_range)
process_user_range(user_range, tuning.resolution, -1)

resolutions = adjusted_resolutions(tuning.goal, ranges, resolutions)
len(t::Tuple{NumericRange,Integer}) = length(iterator(t[1], t[2]))
len(t::Tuple{NominalRange,Integer}) = t[2]
return prod(len.(zip(ranges, resolutions)))

return prod(resolutions)
end
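
For intuition about this change: `prod(resolutions)` can overcount whenever an integer range yields duplicate grid points, so the new code counts what the iterator actually produces. A sketch, assuming MLJBase's `range` and `iterator`:

```julia
import MLJBase
using MLJBase: iterator

r = MLJBase.range(Int, :K, lower=1, upper=5)  # an integer NumericRange

# At resolution 10, rounding to integers leaves only 5 distinct grid points,
# so this returns 5 where the old `prod(resolutions)` would have counted 10:
length(iterator(r, 10))
```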
91 changes: 56 additions & 35 deletions src/tuned_models.jl
@@ -1,13 +1,13 @@
## TYPES AND CONSTRUCTOR

mutable struct DeterministicTunedModel{T,M<:Deterministic,R,A,AR} <: MLJBase.Deterministic
mutable struct DeterministicTunedModel{T,M<:Deterministic,A,AR} <: MLJBase.Deterministic
model::M
tuning::T # tuning strategy
resampling # resampling strategy
measure
weights::Union{Nothing,Vector{<:Real}}
operation
range::R
range
train_best::Bool
repeats::Int
n::Union{Int,Nothing}
@@ -16,14 +16,14 @@ mutable struct DeterministicTunedModel{T,M<:Deterministic,R,A,AR} <: MLJBase.Det
check_measure::Bool
end

mutable struct ProbabilisticTunedModel{T,M<:Probabilistic,R,A,AR} <: MLJBase.Probabilistic
mutable struct ProbabilisticTunedModel{T,M<:Probabilistic,A,AR} <: MLJBase.Probabilistic
model::M
tuning::T # tuning strategy
resampling # resampling strategy
measure
weights::Union{Nothing,AbstractVector{<:Real}}
operation
range::R
range
train_best::Bool
repeats::Int
n::Union{Int,Nothing}
@@ -76,6 +76,11 @@ specified. Query the `strategy` docstring for details. To optimize
over an explicit list `v` of models of the same type, use
`strategy=Explicit()` and specify `model=v[1]` and `range=v`.
The number of models searched is specified by `n`. If unspecified,
then `MLJTuning.default_n(tuning, range)` is used. When `n` is
increased and `fit!(mach)` called again, the old search history is
re-instated and the search continues where it left off.
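
A hedged sketch of this warm-restart behaviour (model, range, and data are illustrative):

```julia
using MLJ

X, y = @load_iris                        # illustrative classification data
tree = @load DecisionTreeClassifier
r = range(tree, :min_purity_increase, lower=0.0001, upper=1.0, scale=:log)

self_tuning = TunedModel(model=tree, tuning=Grid(resolution=20), range=r,
                         resampling=CV(nfolds=6), measure=cross_entropy, n=5)
mach = machine(self_tuning, X, y)
fit!(mach)           # evaluates 5 models

self_tuning.n = 10
fit!(mach)           # history re-instated; only 5 further models evaluated
```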
If `measure` supports weights (`supports_weights(measure) == true`)
then any `weights` specified will be passed to the measure. If more
than one `measure` is specified, then only the first is optimized
@@ -102,10 +107,24 @@ model in the search) will also be passed to `measure` for evaluation.
In the case of two-parameter tuning, a Plots.jl plot of performance
estimates is returned by `plot(mach)` or `heatmap(mach)`.
Once a tuning machine `mach` has been trained as above, one can access
the learned parameters of the best model, using
`fitted_params(mach).best_fitted_params`. Similarly, the report of
training the best model is accessed via `report(mach).best_report`.
Once a tuning machine `mach` has been trained as above, then
`fitted_params(mach)` has these keys/values:
key | value
--------------------|--------------------------------------------------
`best_model` | optimal model instance
`best_fitted_params`| learned parameters of the optimal model
The named tuple `report(mach)` has these keys/values:
key | value
--------------------|--------------------------------------------------
`best_model` | optimal model instance
`best_result` | corresponding "result" entry in the history
`best_report` | report generated by fitting the optimal model
plus others specific to the `tuning` strategy, such as `history=...`.
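
Continuing in that vein, a sketch of inspecting the outcome of a trained tuning machine `mach`, using only the keys listed above:

```julia
fitted_params(mach).best_model           # optimal model instance
fitted_params(mach).best_fitted_params   # learned parameters of that model

report(mach).best_model                  # optimal model instance (again)
report(mach).best_result                 # its "result" entry in the history
report(mach).best_report                 # report from fitting the optimal model
report(mach).history                     # a strategy-specific extra
```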
"""
function TunedModel(;model=nothing,
@@ -119,13 +138,12 @@ function TunedModel(;model=nothing,
range=ranges,
train_best=true,
repeats=1,
n=default_n(tuning, range),
n=nothing,
acceleration=default_resource(),
acceleration_resampling=CPU1(),
check_measure=true)

range === nothing && error("You need to specify `range=...` unless "*
"`tuning isa Explicit`. ")
range === nothing && error("You need to specify `range=...`.")
model == nothing && error("You need to specify model=... .\n"*
"If `tuning=Explicit()`, any model in the "*
"range will do. ")
@@ -218,7 +236,7 @@ _length(history) = length(history)
_length(::Nothing) = 0

# builds on an existing `history` until the length is `n` or the model
# supply is exhausted(method shared by `fit` and `update`). Returns
# supply is exhausted (method shared by `fit` and `update`). Returns
# the bigger history:
function build(history, n, tuning, model::M,
state, verbosity, acceleration, resampling_machine) where M
@@ -231,8 +249,8 @@ function build(history, n, tuning, model::M,
Δj == 0 && (models_exhausted = true)
shortfall = n - Δj
if models_exhausted && shortfall > 0 && verbosity > -1
@warn "Only $j < n = $n` models evaluated.\n"*
"Model supply prematurely exhausted. "
@info "Only $j (of $n) models evaluated.\n"*
"Model supply exhausted. "
end
Δj == 0 && break
shortfall < 0 && (models = models[1:n - j])
@@ -248,10 +266,13 @@ end
function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
verbosity::Integer, data...) where {T,M}
tuning = tuned_model.tuning
n = tuned_model.n
model = tuned_model.model
range = tuned_model.range
n === Nothing && (n = default_n(tuning, range))
n = tuned_model.n === nothing ?
default_n(tuning, range) : tuned_model.n

verbosity < 1 || @info "Attempting to evaluate $n models."

acceleration = tuned_model.acceleration

state = setup(tuning, model, range, verbosity)
@@ -293,29 +314,29 @@ function MLJBase.update(tuned_model::EitherTunedModel, verbosity::Integer,
old_fitresult, old_meta_state, data...)

history, old_tuned_model, state, resampling_machine = old_meta_state

n = tuned_model.n
acceleration = tuned_model.acceleration

if MLJBase.is_same_except(tuned_model, old_tuned_model, :n)
tuning = tuned_model.tuning
range = tuned_model.range
model = tuned_model.model

tuning=tuned_model.tuning
model=tuned_model.model
# exclamation points are for values actually used rather than
# stored:
n! = tuned_model.n === nothing ?
default_n(tuning, range) : tuned_model.n

if tuned_model.n > old_tuned_model.n
# temporarily mutate tuned_model:
tuned_model.n = n - old_tuned_model.n
old_n! = old_tuned_model.n === nothing ?
default_n(tuning, range) : old_tuned_model.n

history = build(history, n, tuning, model, state,
verbosity, acceleration, resampling_machine)
if MLJBase.is_same_except(tuned_model, old_tuned_model, :n) &&
n! >= old_n!

verbosity < 1 || @info "Attempting to add $(n! - old_n!) models "*
"to search, bringing total to $n!. "

history = build(history, n!, tuning, model, state,
verbosity, acceleration, resampling_machine)

# restore tuned_model to original state
tuned_model.n = n
else
verbosity < 1 || @info "Number of tuning iterations `n` "*
"lowered.\nTruncating existing tuning history and "*
"retraining new best model."
end
best_model, best_result = best(tuning, history)

fitresult = machine(best_model, data...)
Expand All @@ -331,7 +352,8 @@ function MLJBase.update(tuned_model::EitherTunedModel, verbosity::Integer,

_report = merge(prereport, tuning_report(tuning, history, state))

meta_state = (history, deepcopy(tuned_model), state)
meta_state = (history, deepcopy(tuned_model), state,
resampling_machine)

return fitresult, meta_state, _report

@@ -374,4 +396,3 @@ MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
MLJBase.input_scitype(M)
MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
MLJBase.target_scitype(M)

2 changes: 1 addition & 1 deletion test/tuned_models.jl
@@ -81,7 +81,7 @@ end
@test map(event -> last(event).measurement[1], history) ≈ results[1:4]

tm.n=100
@test_logs (:warn, r"Only 12") fit!(mach, verbosity=0)
@test_logs (:info, r"Only 12") fit!(mach, verbosity=0)
history = MLJBase.report(mach).history
@test map(event -> last(event).measurement[1], history) ≈ results
end)
