Merge pull request #2 from alan-turing-institute/explore

rename learning_curve! -> learning_curve

ablaom authored Jan 27, 2020
2 parents 8df9e99 + 72523fa commit dab955d
Showing 3 changed files with 34 additions and 32 deletions.
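
For users of the API, the rename simply drops the trailing `!`. A minimal before/after sketch, reusing the `mach` and `r_lambda` names constructed in the docstring example changed below:

```julia
# before this commit:
# curve = learning_curve!(mach; range=r_lambda, resampling=CV(), measure=mav)

# after this commit:
curve = learning_curve(mach; range=r_lambda, resampling=CV(), measure=mav)
```

Old `learning_curve!` calls continue to work via the forwarding method added in src/learning_curves.jl.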
22 changes: 11 additions & 11 deletions README.md
@@ -5,7 +5,7 @@ Hyperparameter optimization for
learning models.

[![Build Status](https://travis-ci.com/alan-turing-institute/MLJTuning.jl.svg?branch=master)](https://travis-ci.com/alan-turing-institute/MLJTuning.jl)
[![Coverage](http://codecov.io/github/alan-turing-institute/MLJTuning.jl/coverage.svg?branch=master)](http://codecov.io/github/alan-turing-institute/MLJTuning.jl?branch=master)
[![Coverage](http://coveralls.io/github/alan-turing-institute/MLJTuning.jl/coverage.svg?branch=master)](http://coveralls.io/github/alan-turing-institute/MLJTuning.jl?branch=master)

*Note:* This component of the [MLJ
stack](https://github.com/alan-turing-institute/MLJ.jl#the-mlj-universe)
@@ -16,17 +16,17 @@ learning models.

## Who is this repo for?

This repository is not intended for the general MLJ user but is:

- a dependency of the
[MLJ](https://github.com/alan-turing-institute/MLJ.jl) machine
learning platform, allowing MLJ users to perform a variety of
hyperparameter optimization tasks
This repository is not intended to be directly imported by the general
MLJ user. Rather, MLJTuning is a dependency of the
[MLJ](https://github.com/alan-turing-institute/MLJ.jl) machine
learning platform, through which MLJ users can perform a variety of
hyperparameter optimization tasks.

- a place for developers to integrate hyperparameter optimization
algorithms (here called *tuning strategies*) into MLJ, either
natively (by adding code to [/src/strategies](/src/strategies)) or
by importing and implementing an interface provided by this repo
MLJTuning is the place for developers to integrate hyperparameter
optimization algorithms (here called *tuning strategies*) into MLJ,
either by adding code to [/src/strategies](/src/strategies), or by
importing MLJTuning into a third-party package and implementing
MLJTuning's interface, as sketched below.
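
As a rough illustration, a third-party integration might begin by declaring a new strategy type. This is a hypothetical sketch only (the `MonteCarlo` name and its field are invented for illustration), although `MLJTuning.TuningStrategy` is the abstract type that the bundled strategies subtype:

```julia
using MLJTuning

# Hypothetical tuning strategy that evaluates `n` randomly sampled
# candidate models. A real implementation must also supply the hook
# methods required by MLJTuning's tuning-strategy interface (not shown).
mutable struct MonteCarlo <: MLJTuning.TuningStrategy
    n::Int
end
```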

MLJTuning is a component of the [MLJ
stack](https://github.com/alan-turing-institute/MLJ.jl#the-mlj-universe)
36 changes: 19 additions & 17 deletions src/learning_curves.jl
@@ -1,17 +1,17 @@
## LEARNING CURVES

"""
curve = learning_curve!(mach; resolution=30,
resampling=Holdout(),
repeats=1,
measure=rms,
weights=nothing,
operation=predict,
range=nothing,
acceleration=default_resource(),
acceleration_grid=CPU1(),
rngs=nothing,
rng_name=nothing)
curve = learning_curve(mach; resolution=30,
resampling=Holdout(),
repeats=1,
measure=rms,
weights=nothing,
operation=predict,
range=nothing,
acceleration=default_resource(),
acceleration_grid=CPU1(),
rngs=nothing,
rng_name=nothing)
Given a supervised machine `mach`, returns a named tuple of objects
suitable for generating a plot of performance estimates, as a function
@@ -34,7 +34,7 @@ atom = @load RidgeRegressor pkg=MultivariateStats
ensemble = EnsembleModel(atom=atom, n=1000)
mach = machine(ensemble, X, y)
r_lambda = range(ensemble, :(atom.lambda), lower=10, upper=500, scale=:log10)
curve = learning_curve!(mach; range=r_lambda, resampling=CV(), measure=mav)
curve = learning_curve(mach; range=r_lambda, resampling=CV(), measure=mav)
using Plots
plot(curve.parameter_values,
curve.measurements,
@@ -52,15 +52,15 @@ of the parameter, i.e. the model is trained progressively.
```julia
atom.lambda=200
r_n = range(ensemble, :n, lower=1, upper=250)
curves = learning_curve!(mach; range=r_n, verbosity=0, rng_name=:rng, rngs=3)
curves = learning_curve(mach; range=r_n, verbosity=0, rng_name=:rng, rngs=3)
plot!(curves.parameter_values,
curves.measurements,
xlab=curves.parameter_name,
ylab="Holdout estimate of RMS error")
```
"""
function learning_curve!(mach::Machine{<:Supervised};
function learning_curve(mach::Machine{<:Supervised};
resolution=30,
resampling=Holdout(),
weights=nothing,
@@ -176,16 +176,18 @@ function _tuning_results(rngs::AbstractVector, acceleration::CPUProcesses,
return ret
end

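# Backward-compatible alias: existing `learning_curve!` calls forward
# to the new `learning_curve` (keyword arguments included).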
learning_curve!(machine::Machine, args...; kwargs...) =
    learning_curve(machine, args...; kwargs...)

"""
learning_curve(model::Supervised, args...; kwargs...)
Plot a learning curve (or curves) without first constructing a
machine. Equivalent to `learing_curve!(machine(model, args...);
machine. Equivalent to `learning_curve(machine(model, args...);
kwargs...)`
See [learning_curve!](@ref)
See [learning_curve](@ref)
"""
learning_curve(model::Supervised, args...; kwargs...) =
learning_curve!(machine(model, args...); kwargs...)
learning_curve(machine(model, args...); kwargs...)
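
A usage sketch for this method, assuming the `ensemble`, `X`, `y`, and `r_lambda` constructed in the docstring example further up:

```julia
# one-step call: the machine is constructed internally
curve = learning_curve(ensemble, X, y; range=r_lambda, resampling=CV(), measure=mav)
```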
8 changes: 4 additions & 4 deletions test/learning_curves.jl
@@ -34,10 +34,10 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.2*rand(100);
if accel == CPU1() && VERSION > v"1.2"
curve = @test_logs((:info, r"No measure"),
(:info, r"Training"),
learning_curve!(mach; range=r_lambda,
learning_curve(mach; range=r_lambda,
acceleration=accel))
else
curve = learning_curve!(mach; range=r_lambda,
curve = learning_curve(mach; range=r_lambda,
acceleration=accel)
end
@test curve isa NamedTuple{(:parameter_name,
@@ -48,7 +48,7 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.2*rand(100);
atom.lambda=0.3
r_n = range(ensemble, :n, lower=10, upper=100)

curves = learning_curve!(mach; range=r_n, resolution=7,
curves = learning_curve(mach; range=r_n, resolution=7,
acceleration=accel,
rngs = MersenneTwister.(1:3),
rng_name=:rng)
@@ -60,7 +60,7 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.2*rand(100);
@test !(curves.measurements[1,1] ≈ curves.measurements[1,3])

# reproducibility:
curves2 = learning_curve!(mach; range=r_n, resolution=7,
curves2 = learning_curve(mach; range=r_n, resolution=7,
acceleration=accel,
rngs = 3,
rng_name=:rng)
