Skip to content

Commit

Permalink
Update 'Allow Optim.Options to pass to BAT.find_mode' to current main…
Browse files Browse the repository at this point in the history
… branch
  • Loading branch information
Micki-D committed May 2, 2024
1 parent 67ac2bf commit dc9112e
Show file tree
Hide file tree
Showing 14 changed files with 297 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
*.jl.mem
.ipynb_checkpoints
Manifest.toml
.vscode/settings.json
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"
Expand All @@ -82,10 +84,12 @@ BATFoldsExt = ["Folds", "Transducers"]
BATHDF5Ext = "HDF5"
BATNestedSamplersExt = "NestedSamplers"
BATOptimExt = "Optim"
BATOptimizationExt = ["Optimization", "ADTypes"]
BATPlotsExt = "Plots"
BATUltraNestExt = "UltraNest"

[compat]
ADTypes = "0.1, 0.2"
Accessors = "0.1"
Adapt = "3, 4"
AdvancedHMC = "0.5, 0.6"
Expand Down Expand Up @@ -132,6 +136,7 @@ Measurements = "2"
NamedArrays = "0.9, 0.10"
NestedSamplers = "0.8"
Optim = "0.19,0.20, 0.21, 0.22, 1"
Optimization = "3"
PDMats = "0.9, 0.10, 0.11"
ParallelProcessingTools = "0.4"
Parameters = "0.12, 0.13"
Expand All @@ -158,12 +163,14 @@ ZygoteRules = "0.2"
julia = "1.6"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09"
16 changes: 16 additions & 0 deletions docs/src/list_of_algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,22 @@ bat_findmode(target, OptimAlg(optalg = Optim.LBFGS()))

Requires the [Optim](https://github.com/JuliaNLSolvers/Optim.jl) Julia package to be loaded explicitly.

### Optimization.jl Optimization Algorithms

BAT mode finding algorithm type: [`OptimizationAlg`](@ref).

```julia
using OptimizationOptimJL

alg = OptimizationAlg(;
optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10),
maxiters=200,
kwargs=(f_calls_limit=50,)
)
bat_findmode(target, alg)
```
Requires one of the [Optimization.jl](https://github.com/SciML/Optimization.jl) packages to be loaded explicitly.


### Maximum Sample Estimator

Expand Down
1 change: 1 addition & 0 deletions docs/src/stable_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ MetropolisHastings
MHProposalDistTuning
ModeAsDefined
OptimAlg
OptimizationAlg
OrderedResampling
PosteriorMeasure
PriorSubstitution
Expand Down
101 changes: 101 additions & 0 deletions examples/dev-internal/test_findmode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
using BAT
using Optim

posterior = BAT.example_posterior()

optalg = OptimAlg(;
optalg = Optim.NelderMead(parameters=Optim.FixedParameters()),
maxiters=200,
kwargs = (f_calls_limit=100,),
)

my_mode = bat_findmode(posterior, optalg)

fieldnames(typeof(my_mode.info.res))


using BAT
#using Optim
#using Optimization
using OptimizationOptimJL

using InverseFunctions, FunctionChains, DensityInterface


posterior = BAT.example_posterior()
optalg = OptimizationAlg(; optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), maxiters=200, kwargs=(f_calls_limit=500,))
my_result = bat_findmode(posterior, optalg)

a = my_result.info

@test a.cache.solver_args.maxiters == 500

dump(a.alg)

fieldnames(typeof(a.cache.solver_args))

fieldnames(typeof(a.original.method))

my_mode.info.original




# Define a NamedTuple with keyword arguments
nt = (a=1, b=2)

# Define a function that accepts keyword arguments
function my_function(; a=0, b=0, c=0)
println("a = $a")
println("b = $b")
println("c = $c")
end

# Call the function and unpack the NamedTuple
my_function(; nt...)




context = get_batcontext()
target = posterior
transformed_density, trafo = BAT.transform_and_unshape(PriorToGaussian(), target, context)
inv_trafo = inverse(trafo)
initalg = BAT.apply_trafo_to_init(trafo, InitFromTarget())
x_init = collect(bat_initval(transformed_density, initalg, context).result)

f = fchain(inv_trafo, logdensityof(target), -)
f2 = (x, p) -> f(x)


optimization_function = Optimization.OptimizationFunction(f2, Optimization.SciMLBase.NoAD())
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)
optimization_result = Optimization.solve(optimization_problem,OptimizationOptimJL.NelderMead())


optalg = OptimizationAlg(;optalg = OptimizationOptimJL.NelderMead())
my_mode = bat_findmode(posterior, optalg)

my_mode.info.original
fieldnames(typeof(my_mode.info))

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
f = rosenbrock


using AutoDiffOperators

b = Optimization.SciMLBase.NoAD()
supertype(typeof(b))

adm = ADModule(:ForwardDiff)

adsel = BAT.get_adselector(context)
supertype(typeof(adsel))



adm2 = convert_ad(ADTypes.AbstractADType, adm)
ADTypes.AutoForwardDiff()

optimization_function = Optimization.OptimizationFunction(f2, adm2)
23 changes: 17 additions & 6 deletions ext/BATOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:DEFAULT_OPTALG}) = Optim.
BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:NELDERMEAD_ALG}) = Optim.NelderMead()
BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:LBFGS_ALG}) = Optim.LBFGS()


struct NLSolversFG!{F,AD} <: Function
f::F
ad::AD
Expand Down Expand Up @@ -58,6 +57,19 @@ function (fg!::NLSolversFG!)(::Nothing, grad_f::AbstractVector{<:Real}, x::Abstr
return Nothing
end

function convert_options(algorithm::OptimAlg)
if !isnan(algorithm.abstol)
@warn "The option 'abstol=$(algorithm.abstol)' is not used for this algorithm."

Check warning on line 62 in ext/BATOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimExt.jl#L62

Added line #L62 was not covered by tests
end

kwargs = algorithm.kwargs

algopts = (; iterations = algorithm.maxiters, time_limit = algorithm.maxtime, f_tol = algorithm.reltol,)
algopts = (; algopts..., kwargs...)
algopts = (; algopts..., store_trace = true, extended_trace=true)

return Optim.Options(; algopts...)
end

function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context::BATContext)
transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context)
Expand All @@ -68,7 +80,8 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
optim_result = _optim_minimize(f, x_init, algorithm.optalg, context)
opts = convert_options(algorithm)
optim_result = _optim_minimize(f, x_init, algorithm.optalg, opts, context)
r_optim = Optim.MaximizationWrapper(optim_result)
transformed_mode = Optim.minimizer(r_optim.res)
result_mode = inv_trafo(transformed_mode)
Expand All @@ -80,18 +93,16 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context
(result = result_mode, result_trafo = transformed_mode, trafo = trafo, #=trace_trafo = trace_trafo,=# info = r_optim)
end

function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.ZerothOrderOptimizer, ::BATContext)
opts = Optim.Options(store_trace = true, extended_trace=true)
function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.ZerothOrderOptimizer, opts::Optim.Options, ::BATContext)
_optim_optimize(f, x_init, algorithm, opts)
end

function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.FirstOrderOptimizer, context::BATContext)
function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.FirstOrderOptimizer, opts::Optim.Options, context::BATContext)
adsel = get_adselector(context)
if adsel isa _NoADSelected
throw(ErrorException("$(nameof(typeof(algorithm))) requires an ADSelector to be specified in the BAT context"))
end
fg! = NLSolversFG!(f, adsel)
opts = Optim.Options(store_trace = true, extended_trace=true)
_optim_optimize(Optim.only_fg!(fg!), x_init, algorithm, opts)
end

Expand Down
72 changes: 72 additions & 0 deletions ext/BATOptimizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).

module BATOptimizationExt

@static if isdefined(Base, :get_extension)
import Optimization
else
import ..Optimization
end

using BAT
BAT.pkgext(::Val{:Optimization}) = BAT.PackageExtension{:Optimization}()

Check warning on line 12 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L12

Added line #L12 was not covered by tests


using Random
using DensityInterface, ChangesOfVariables, InverseFunctions, FunctionChains
using HeterogeneousComputing, AutoDiffOperators
using StructArrays, ArraysOfArrays, ADTypes

using BAT: MeasureLike

using BAT: get_context, get_adselector, _NoADSelected
using BAT: bat_initval, transform_and_unshape, apply_trafo_to_init
# using BAT: negative #deprecated?


AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg

Check warning on line 28 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L27-L28

Added lines #L27 - L28 were not covered by tests

BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()

Check warning on line 30 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L30

Added line #L30 was not covered by tests


function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
adm = convert_ad(ADTypes.AbstractADType, adsel)
optimization_function = Optimization.OptimizationFunction(f, adm)
return optimization_function

Check warning on line 36 in ext/BATOptimizationExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BATOptimizationExt.jl#L33-L36

Added lines #L33 - L36 were not covered by tests
end

function build_optimizationfunction(f, adsel::BAT._NoADSelected)
optimization_function = Optimization.OptimizationFunction(f)
return optimization_function
end


function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg, context::BATContext)
transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context)
inv_trafo = inverse(trafo)

initalg = apply_trafo_to_init(trafo, algorithm.init)
x_init = collect(bat_initval(transformed_density, initalg, context).result)

# Maximize density of original target, but run in transformed space, don't apply LADJ:
f = fchain(inv_trafo, logdensityof(target), -)
target_f = (x, p) -> f(x)

adsel = get_adselector(context)

optimization_function = build_optimizationfunction(target_f, adsel)
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)

algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)

transformed_mode = optimization_result.u
result_mode = inv_trafo(transformed_mode)

(result = result_mode, result_trafo = transformed_mode, trafo = trafo, info = optimization_result)
end



end # module BATOptimizationExt
1 change: 1 addition & 0 deletions src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ function __init__()
@require HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" include("../ext/BATHDF5Ext.jl")
@require NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" include("../ext/BATNestedSamplersExt.jl")
@require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("../ext/BATOptimExt.jl")
@require Optimization = "429524aa-4258-5aef-a3af-852621145aeb" @require ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" include("../ext/BATOptimizationExt.jl")
@require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("../ext/BATPlotsExt.jl")
@require UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" include("../ext/BATUltraNestExt.jl")
end
Expand Down
1 change: 1 addition & 0 deletions src/extdefs/extdefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ include("advancedhmc_defs.jl")
include("cuba_defs.jl")
include("nestedsamplers_defs.jl")
include("optim_defs.jl")
include("optimization_defs.jl")
include("ultranest_defs.jl")
5 changes: 5 additions & 0 deletions src/extdefs/optim_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,10 @@ $(TYPEDFIELDS)
optalg::ALG = ext_default(pkgext(Val(:Optim)), Val(:DEFAULT_OPTALG))
trafo::TR = PriorToGaussian()
init::IA = InitFromTarget()
maxiters::Int = 1_000
maxtime::Float64 = NaN
abstol::Float64 = NaN
reltol::Float64 = 0.0
kwargs::NamedTuple = (;)
end
export OptimAlg
36 changes: 36 additions & 0 deletions src/extdefs/optimization_defs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).

"""
OptimizationAlg
Selects an optimization algorithm from the
[Optimization.jl](https://github.com/SciML/Optimization.jl)
package.
Note that when using first order algorithms like `OptimizationOptimJL.LBFGS`, your
[`BATContext`](@ref) needs to include an `ADSelector` that specifies
which automatic differentiation backend should be used.
Constructors:
* ```$(FUNCTIONNAME)(; fields...)```
`optalg` must be an `Optimization.AbstractOptimizer`.
The field `kwargs` can be used to pass additional keywords to the optimizers
See the [Optimization.jl documentation](https://docs.sciml.ai/Optimization/stable/) for the available keyword arguments.
Fields:
$(TYPEDFIELDS)
!!! note
This algorithm is only available if the `Optimization` package or any of its submodules, like `OptimizationOptimJL`, is loaded (e.g. via
`import Optimization`).
"""
@with_kw struct OptimizationAlg{
ALG,
TR<:AbstractTransformTarget,
IA<:InitvalAlgorithm
} <: AbstractModeEstimator
optalg::ALG = ext_default(pkgext(Val(:Optimization)), Val(:DEFAULT_OPTALG))
trafo::TR = PriorToGaussian()
init::IA = InitFromTarget()
maxiters::Int64 = 1_000
maxtime::Float64 = NaN
abstol::Float64 = NaN
reltol::Float64 = 0.0
kwargs::NamedTuple = (;)
end
export OptimizationAlg
1 change: 1 addition & 0 deletions src/samplers/importance/importance_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ function bat_sample_impl(

est_integral = mean(weights)
# ToDo: Add integral error estimate
# @show samples #disable for testing

transformed_smpls = DensitySampleVector(samples, logvals, weight = weights)
smpls = inverse(trafo).(transformed_smpls)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ParallelProcessingTools = "8e8a01fc-6193-5ca1-a2f1-20776dae4199"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
Loading

0 comments on commit dc9112e

Please sign in to comment.