diff --git a/.gitignore b/.gitignore index 318e444db..719614c32 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ *.jl.mem .ipynb_checkpoints Manifest.toml +.vscode/settings.json diff --git a/Project.toml b/Project.toml index 5f55655fd..ebe969c69 100644 --- a/Project.toml +++ b/Project.toml @@ -65,12 +65,14 @@ ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" @@ -82,10 +84,12 @@ BATFoldsExt = ["Folds", "Transducers"] BATHDF5Ext = "HDF5" BATNestedSamplersExt = "NestedSamplers" BATOptimExt = "Optim" +BATOptimizationExt = ["Optimization", "ADTypes"] BATPlotsExt = "Plots" BATUltraNestExt = "UltraNest" [compat] +ADTypes = "0.1, 0.2" Accessors = "0.1" Adapt = "3, 4" AdvancedHMC = "0.5, 0.6" @@ -132,6 +136,7 @@ Measurements = "2" NamedArrays = "0.9, 0.10" NestedSamplers = "0.8" Optim = "0.19,0.20, 0.21, 0.22, 1" +Optimization = "3" PDMats = "0.9, 0.10, 0.11" ParallelProcessingTools = "0.4" Parameters = "0.12, 0.13" @@ -158,12 +163,14 @@ ZygoteRules = "0.2" julia = "1.6" [extras] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1" Folds = "41a02a25-b8f0-4f67-bc48-60067656b558" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" diff --git a/docs/src/list_of_algorithms.md b/docs/src/list_of_algorithms.md index 4688a3800..5bca8d7a0 100644 --- a/docs/src/list_of_algorithms.md +++ b/docs/src/list_of_algorithms.md @@ -166,6 +166,22 @@ bat_findmode(target, OptimAlg(optalg = Optim.LBFGS())) Requires the [Optim](https://github.com/JuliaNLSolvers/Optim.jl) Julia package to be loaded explicitly. +### Optimization.jl Optimization Algorithms + +BAT mode finding algorithm type: [`OptimizationAlg`](@ref). + +```julia +using OptimizationOptimJL + +alg = OptimizationAlg(; + optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), + maxiters=200, + kwargs=(f_calls_limit=50,) +) +bat_findmode(target, alg) +``` +Requires one of the [Optimization.jl](https://github.com/SciML/Optimization.jl) packages to be loaded explicitly. + ### Maximum Sample Estimator diff --git a/docs/src/stable_api.md b/docs/src/stable_api.md index e294572de..353c2d6de 100644 --- a/docs/src/stable_api.md +++ b/docs/src/stable_api.md @@ -92,6 +92,7 @@ MetropolisHastings MHProposalDistTuning ModeAsDefined OptimAlg +OptimizationAlg OrderedResampling PosteriorMeasure PriorSubstitution diff --git a/examples/dev-internal/test_findmode.jl b/examples/dev-internal/test_findmode.jl new file mode 100644 index 000000000..eb051c106 --- /dev/null +++ b/examples/dev-internal/test_findmode.jl @@ -0,0 +1,101 @@ +using BAT +using Optim + +posterior = BAT.example_posterior() + +optalg = OptimAlg(; + optalg = Optim.NelderMead(parameters=Optim.FixedParameters()), + maxiters=200, + kwargs = (f_calls_limit=100,), +) + +my_mode = bat_findmode(posterior, optalg) + +fieldnames(typeof(my_mode.info.res)) + + +using BAT +#using Optim +#using Optimization +using OptimizationOptimJL + +using InverseFunctions, FunctionChains, DensityInterface + + +posterior = BAT.example_posterior() +optalg = OptimizationAlg(; optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), maxiters=200, kwargs=(f_calls_limit=500,)) +my_result = bat_findmode(posterior, optalg) + +a = my_result.info + +@test a.cache.solver_args.maxiters == 500 + +dump(a.alg) + +fieldnames(typeof(a.cache.solver_args)) + +fieldnames(typeof(a.original.method)) + +my_mode.info.original + + + + +# Define a NamedTuple with keyword arguments +nt = (a=1, b=2) + +# Define a function that accepts keyword arguments +function my_function(; a=0, b=0, c=0) + println("a = $a") + println("b = $b") + println("c = $c") +end + +# Call the function and unpack the NamedTuple +my_function(; nt...) + + + + +context = get_batcontext() +target = posterior +transformed_density, trafo = BAT.transform_and_unshape(PriorToGaussian(), target, context) +inv_trafo = inverse(trafo) +initalg = BAT.apply_trafo_to_init(trafo, InitFromTarget()) +x_init = collect(bat_initval(transformed_density, initalg, context).result) + +f = fchain(inv_trafo, logdensityof(target), -) +f2 = (x, p) -> f(x) + + +optimization_function = Optimization.OptimizationFunction(f2, Optimization.SciMLBase.NoAD()) +optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init) +optimization_result = Optimization.solve(optimization_problem,OptimizationOptimJL.NelderMead()) + + +optalg = OptimizationAlg(;optalg = OptimizationOptimJL.NelderMead()) +my_mode = bat_findmode(posterior, optalg) + +my_mode.info.original +fieldnames(typeof(my_mode.info)) + +rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 +f = rosenbrock + + +using AutoDiffOperators + +b = Optimization.SciMLBase.NoAD() +supertype(typeof(b)) + +adm = ADModule(:ForwardDiff) + +adsel = BAT.get_adselector(context) +supertype(typeof(adsel)) + + + +adm2 = convert_ad(ADTypes.AbstractADType, adm) +ADTypes.AutoForwardDiff() + +optimization_function = Optimization.OptimizationFunction(f2, adm2) diff --git a/ext/BATOptimExt.jl b/ext/BATOptimExt.jl index 77557de86..8fa1c2a72 100644 --- a/ext/BATOptimExt.jl +++ b/ext/BATOptimExt.jl @@ -30,7 +30,6 @@ BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:DEFAULT_OPTALG}) = Optim. BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:NELDERMEAD_ALG}) = Optim.NelderMead() BAT.ext_default(::BAT.PackageExtension{:Optim}, ::Val{:LBFGS_ALG}) = Optim.LBFGS() - struct NLSolversFG!{F,AD} <: Function f::F ad::AD @@ -58,6 +57,19 @@ function (fg!::NLSolversFG!)(::Nothing, grad_f::AbstractVector{<:Real}, x::Abstr return Nothing end +function convert_options(algorithm::OptimAlg) + if !isnan(algorithm.abstol) + @warn "The option 'abstol=$(algorithm.abstol)' is not used for this algorithm." + end + + kwargs = algorithm.kwargs + + algopts = (; iterations = algorithm.maxiters, time_limit = algorithm.maxtime, f_tol = algorithm.reltol,) + algopts = (; algopts..., kwargs...) + algopts = (; algopts..., store_trace = true, extended_trace=true) + + return Optim.Options(; algopts...) +end function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context::BATContext) transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context) @@ -68,7 +80,8 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context # Maximize density of original target, but run in transformed space, don't apply LADJ: f = fchain(inv_trafo, logdensityof(target), -) - optim_result = _optim_minimize(f, x_init, algorithm.optalg, context) + opts = convert_options(algorithm) + optim_result = _optim_minimize(f, x_init, algorithm.optalg, opts, context) r_optim = Optim.MaximizationWrapper(optim_result) transformed_mode = Optim.minimizer(r_optim.res) result_mode = inv_trafo(transformed_mode) @@ -80,18 +93,16 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimAlg, context (result = result_mode, result_trafo = transformed_mode, trafo = trafo, #=trace_trafo = trace_trafo,=# info = r_optim) end -function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.ZerothOrderOptimizer, ::BATContext) - opts = Optim.Options(store_trace = true, extended_trace=true) +function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.ZerothOrderOptimizer, opts::Optim.Options, ::BATContext) _optim_optimize(f, x_init, algorithm, opts) end -function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.FirstOrderOptimizer, context::BATContext) +function _optim_minimize(f::Function, x_init::AbstractArray{<:Real}, algorithm::Optim.FirstOrderOptimizer, opts::Optim.Options, context::BATContext) adsel = get_adselector(context) if adsel isa _NoADSelected throw(ErrorException("$(nameof(typeof(algorithm))) requires an ADSelector to be specified in the BAT context")) end fg! = NLSolversFG!(f, adsel) - opts = Optim.Options(store_trace = true, extended_trace=true) _optim_optimize(Optim.only_fg!(fg!), x_init, algorithm, opts) end diff --git a/ext/BATOptimizationExt.jl b/ext/BATOptimizationExt.jl new file mode 100644 index 000000000..9c0b60d11 --- /dev/null +++ b/ext/BATOptimizationExt.jl @@ -0,0 +1,72 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +module BATOptimizationExt + +@static if isdefined(Base, :get_extension) + import Optimization +else + import ..Optimization +end + +using BAT +BAT.pkgext(::Val{:Optimization}) = BAT.PackageExtension{:Optimization}() + + +using Random +using DensityInterface, ChangesOfVariables, InverseFunctions, FunctionChains +using HeterogeneousComputing, AutoDiffOperators +using StructArrays, ArraysOfArrays, ADTypes + +using BAT: MeasureLike + +using BAT: get_context, get_adselector, _NoADSelected +using BAT: bat_initval, transform_and_unshape, apply_trafo_to_init +# using BAT: negative #deprecated? + + +AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg) +convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg + +BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead() + + +function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector) + adm = convert_ad(ADTypes.AbstractADType, adsel) + optimization_function = Optimization.OptimizationFunction(f, adm) + return optimization_function +end + +function build_optimizationfunction(f, adsel::BAT._NoADSelected) + optimization_function = Optimization.OptimizationFunction(f) + return optimization_function +end + + +function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg, context::BATContext) + transformed_density, trafo = transform_and_unshape(algorithm.trafo, target, context) + inv_trafo = inverse(trafo) + + initalg = apply_trafo_to_init(trafo, algorithm.init) + x_init = collect(bat_initval(transformed_density, initalg, context).result) + + # Maximize density of original target, but run in transformed space, don't apply LADJ: + f = fchain(inv_trafo, logdensityof(target), -) + target_f = (x, p) -> f(x) + + adsel = get_adselector(context) + + optimization_function = build_optimizationfunction(target_f, adsel) + optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init) + + algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol) + optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...) + + transformed_mode = optimization_result.u + result_mode = inv_trafo(transformed_mode) + + (result = result_mode, result_trafo = transformed_mode, trafo = trafo, info = optimization_result) +end + + + +end # module BATOptimizationExt diff --git a/src/BAT.jl b/src/BAT.jl index dea2468d8..ce83c8f35 100644 --- a/src/BAT.jl +++ b/src/BAT.jl @@ -141,6 +141,7 @@ function __init__() @require HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" include("../ext/BATHDF5Ext.jl") @require NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" include("../ext/BATNestedSamplersExt.jl") @require Optim = "429524aa-4258-5aef-a3af-852621145aeb" include("../ext/BATOptimExt.jl") + @require Optimization = "429524aa-4258-5aef-a3af-852621145aeb" @require ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" include("../ext/BATOptimizationExt.jl") @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("../ext/BATPlotsExt.jl") @require UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" include("../ext/BATUltraNestExt.jl") end diff --git a/src/extdefs/extdefs.jl b/src/extdefs/extdefs.jl index 19f7ba2f2..516a1295e 100644 --- a/src/extdefs/extdefs.jl +++ b/src/extdefs/extdefs.jl @@ -4,4 +4,5 @@ include("advancedhmc_defs.jl") include("cuba_defs.jl") include("nestedsamplers_defs.jl") include("optim_defs.jl") +include("optimization_defs.jl") include("ultranest_defs.jl") diff --git a/src/extdefs/optim_defs.jl b/src/extdefs/optim_defs.jl index 656fb2957..c993a2a2a 100644 --- a/src/extdefs/optim_defs.jl +++ b/src/extdefs/optim_defs.jl @@ -35,5 +35,10 @@ $(TYPEDFIELDS) optalg::ALG = ext_default(pkgext(Val(:Optim)), Val(:DEFAULT_OPTALG)) trafo::TR = PriorToGaussian() init::IA = InitFromTarget() + maxiters::Int = 1_000 + maxtime::Float64 = NaN + abstol::Float64 = NaN + reltol::Float64 = 0.0 + kwargs::NamedTuple = (;) end export OptimAlg diff --git a/src/extdefs/optimization_defs.jl b/src/extdefs/optimization_defs.jl new file mode 100644 index 000000000..ee7f66b54 --- /dev/null +++ b/src/extdefs/optimization_defs.jl @@ -0,0 +1,36 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +""" + OptimizationAlg +Selects an optimization algorithm from the +[Optimization.jl](https://github.com/SciML/Optimization.jl) +package. +Note that when using first order algorithms like `OptimizationOptimJL.LBFGS`, your +[`BATContext`](@ref) needs to include an `ADSelector` that specifies +which automatic differentiation backend should be used. +Constructors: +* ```$(FUNCTIONNAME)(; fields...)``` +`optalg` must be an `Optimization.AbstractOptimizer`. +The field `kwargs` can be used to pass additional keywords to the optimizers +See the [Optimization.jl documentation](https://docs.sciml.ai/Optimization/stable/) for the available keyword arguments. +Fields: +$(TYPEDFIELDS) +!!! note + This algorithm is only available if the `Optimization` package or any of its submodules, like `OptimizationOptimJL`, is loaded (e.g. via + `import Optimization`). +""" +@with_kw struct OptimizationAlg{ + ALG, + TR<:AbstractTransformTarget, + IA<:InitvalAlgorithm +} <: AbstractModeEstimator + optalg::ALG = ext_default(pkgext(Val(:Optimization)), Val(:DEFAULT_OPTALG)) + trafo::TR = PriorToGaussian() + init::IA = InitFromTarget() + maxiters::Int64 = 1_000 + maxtime::Float64 = NaN + abstol::Float64 = NaN + reltol::Float64 = 0.0 + kwargs::NamedTuple = (;) +end +export OptimizationAlg diff --git a/src/samplers/importance/importance_sampler.jl b/src/samplers/importance/importance_sampler.jl index d3e436571..b2d0e6679 100644 --- a/src/samplers/importance/importance_sampler.jl +++ b/src/samplers/importance/importance_sampler.jl @@ -62,6 +62,7 @@ function bat_sample_impl( est_integral = mean(weights) # ToDo: Add integral error estimate + # @show samples #disable for testing transformed_smpls = DensitySampleVector(samples, logvals, weight = weights) smpls = inverse(trafo).(transformed_smpls) diff --git a/test/Project.toml b/test/Project.toml index 0c78d8a1e..a5cf31000 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -28,6 +28,7 @@ MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ParallelProcessingTools = "8e8a01fc-6193-5ca1-a2f1-20776dae4199" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/test/optimization/test_mode_estimators.jl b/test/optimization/test_mode_estimators.jl index b20f92a6e..ad0720c16 100644 --- a/test/optimization/test_mode_estimators.jl +++ b/test/optimization/test_mode_estimators.jl @@ -4,7 +4,7 @@ using Test using LinearAlgebra, Distributions, StatsBase, ValueShapes, Random123, DensityInterface using UnPack, InverseFunctions using AutoDiffOperators, ForwardDiff -using Optim +using Optim, OptimizationOptimJL @testset "mode_estimators" begin prior = NamedTupleDist( @@ -71,17 +71,51 @@ using Optim end - @testset "NelderMead" begin + @testset "Optim.jl - NelderMead" begin context = BATContext(rng = Philox4x((0, 0))) test_findmode(posterior, OptimAlg(optalg = NelderMead(), trafo = DoNotTransform()), 0.01, context) end + @testset "Optim.jl with custom options" begin # checks that options are correctly passed to Optim.jl + context = BATContext(rng = Philox4x((0, 0))) + optimizer = OptimAlg(optalg = NelderMead(), trafo = DoNotTransform(), maxiters=20, maxtime=30, reltol=0.2, kwargs=(f_calls_limit=25,)) + + result = bat_findmode(posterior, optimizer, context) + @test result.info.res.iterations <= 20 + @test result.info.res.time_limit == 30 + @test result.info.res.f_reltol == 0.2 + @test result.info.res.f_calls <= 26 + + end - @testset "LBFGS" begin + @testset "Optim.jl - LBFGS" begin context = BATContext(rng = Philox4x((0, 0)), ad = ADModule(:ForwardDiff)) # Result Optim.maximize with LBFGS is not type-stable: test_findmode(posterior, OptimAlg(optalg = LBFGS(), trafo = DoNotTransform()), 0.01, inferred = false, context) test_findmode_ctx(posterior, OptimAlg(optalg = LBFGS(), trafo = DoNotTransform()), 0.01, context) end + @testset "Optimization.jl - NelderMead" begin + context = BATContext(rng = Philox4x((0, 0))) + # result is not type-stable: + test_findmode(posterior, OptimizationAlg(optalg = OptimizationOptimJL.NelderMead(), trafo = DoNotTransform()), 0.01, context, inferred = false) + end + + @testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl + context = BATContext(rng = Philox4x((0, 0))) + optimizer = OptimizationAlg(optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), maxiters=200, kwargs=(f_calls_limit=500,), trafo=DoNotTransform()) + + # result is not type-stable: + test_findmode(posterior, optimizer, 0.01, context, inferred = false) + + optimizer = OptimizationAlg(optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), + maxiters=200, maxtime=30, reltol=0.2, kwargs=(f_calls_limit=500,), trafo=DoNotTransform()) + + result = bat_findmode(posterior, optimizer, context) + @test result.info.cache.solver_args.maxiters == 200 + @test result.info.cache.solver_args.f_calls_limit == 500 + @test result.info.cache.solver_args.reltol == 0.2 + @test result.info.cache.solver_args.maxtime == 30 + @test result.info.original.method.n_particles == 10 + end end