Skip to content

Commit

Permalink
Merge branch 'main' into hmc_proposed_sample_handling_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Micki-D authored Feb 12, 2025
2 parents 71d6dee + bf65463 commit e389e4f
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
with:
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
file: lcov.info
files: lcov.info
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BAT"
uuid = "c0cd4b16-88b7-57fa-983b-ab80aecada7e"
version = "3.3.1"
version = "4.0.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -120,7 +120,7 @@ Folds = "0.2"
ForwardDiff = "0.10"
ForwardDiffPullbacks = "0.1.1, 0.2"
FunctionChains = "0.1.4"
Functors = "0.2, 0.3, 0.4"
Functors = "0.2, 0.3, 0.4, 0.5"
HDF5 = "0.15, 0.16, 0.17"
HeterogeneousComputing = "0.2"
HypothesisTests = "0.10, 0.11"
Expand Down
8 changes: 5 additions & 3 deletions ext/BATOptimizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ function test_bat_optimization_ext()
end

AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg)
convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg
Base.convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg

BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead()


function build_optimizationfunction(f, adsel::AutoDiffOperators.ADSelector)
adm = convert_ad(ADTypes.AbstractADType, adsel)
adm = convert(ADTypes.AbstractADType, reverse_ad_selector(adsel))
optimization_function = Optimization.OptimizationFunction(f, adm)
return optimization_function
end
Expand Down Expand Up @@ -59,7 +59,9 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg,
optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init)

algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol)
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; algopts..., algorithm.kwargs...)
# Not all algorithms support abstol, just filter all NaN-valued opts out:
filtered_algopts = NamedTuple(filter(p -> !isnan(p[2]), pairs(algopts)))
optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; filtered_algopts..., algorithm.kwargs...)

transformed_mode = optimization_result.u
result_mode = inv_trafo(transformed_mode)
Expand Down
2 changes: 1 addition & 1 deletion src/BAT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ using DomainSets: UnitInterval, UnitCube, Rectangle, FullSpace, RealNumbers

using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent, AbstractThunk, unthunk

using Functors: fmap, @functor
using Functors: fmap

# For Dual specializations:
import ForwardDiff
Expand Down
5 changes: 5 additions & 0 deletions test/optimization/test_mode_estimators.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using BAT
using Test

using AutoDiffOperators
using LinearAlgebra, Distributions, StatsBase, ValueShapes, Random123, DensityInterface
using UnPack, InverseFunctions
import ForwardDiff
Expand Down Expand Up @@ -101,6 +102,10 @@ using Optim, OptimizationOptimJL
context = BATContext(rng = Philox4x((0, 0)))
# result is not type-stable:
test_findmode(posterior, OptimizationAlg(optalg = OptimizationOptimJL.NelderMead(), pretransform = DoNotTransform()), 0.01, context, inferred = false)

context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff))
# result is not type-stable:
test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), pretransform = DoNotTransform()), 0.01, context, inferred = false)
end

@testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl
Expand Down

0 comments on commit e389e4f

Please sign in to comment.