Skip to content

Commit

Permalink
Add AbstractCostFunction in ParamEstim
Browse files Browse the repository at this point in the history
  • Loading branch information
gerlero committed Dec 25, 2023
1 parent 5ce4f22 commit 09402f8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/src/ParamEstim.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The `ParamEstim` submodule provides support for optimization-based parameter est

```@docs
ScaledSolution
AbstractCostFunction
RSSCostFunction
candidate
```
61 changes: 39 additions & 22 deletions src/ParamEstim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,42 @@ function Base.show(io::IO, sol::ScaledSolution)
end

"""
RSSCostFunction{fit_D0}(func, prob::InverseProblem[; D0tol, oi_hint])
abstract type AbstractCostFunction{fit_D0} end
Abstract cost function for parameter estimation.
# Type parameters
- `fit_D0::Bool`: whether to fit an additional constant factor `D0` that affects the diffusivity. Values
of `D0` can be found with relative efficiency without additional solver calls; so if any such constant
factors affecting the diffusivity are unknown, it is recommended not to fit those factors directly but set
`fit_D0` to `true` instead. Values of `D0` are found internally by local optimization. If `true`, the
`candidate` function will return a `ScaledSolution` that includes the found value of `D0`.
"""
abstract type AbstractCostFunction{fit_D0} end

function (cf::AbstractCostFunction)(params::AbstractVector,
::NullParameters = NullParameters())
cf(candidate(cf, params))
end

_solve(cf::AbstractCostFunction, params::AbstractVector) = _solve(cf, cf._func(params))
_solve(::AbstractCostFunction, prob::AbstractProblem) = solve(prob, verbose = false)
_solve(::AbstractCostFunction, sol::Solution) = sol

"""
candidate(cf::AbstractCostFunction, ::AbstractVector)
candidate(cf::AbstractCostFunction, ::Fronts.AbstractProblem)
candidate(cf::AbstractCostFunction, ::Fronts.Solution)
Return the candidate solution for a given cost function and parameter values, problem, or solution.
"""
candidate(cf::AbstractCostFunction, params::AbstractVector) = candidate(cf,
_solve(cf, params))
candidate(cf::AbstractCostFunction, prob::AbstractProblem) = candidate(cf, _solve(cf, prob))
candidate(::AbstractCostFunction{false}, sol::Solution) = sol

"""
RSSCostFunction{fit_D0}(func, prob::InverseProblem[; D0tol, oi_hint]) <: AbstractCostFunction
Residual sum of squares cost function for parameter estimation.
Expand Down Expand Up @@ -97,7 +132,8 @@ objective function.
If you need to know more than just the cost, call the `candidate` function instead.
"""
struct RSSCostFunction{fit_D0, _Tfunc, _Tprob, _TD0tol, _Toi_hint, _Tsorptivity}
struct RSSCostFunction{fit_D0, _Tfunc, _Tprob, _TD0tol, _Toi_hint, _Tsorptivity} <:
AbstractCostFunction{fit_D0}
_func::_Tfunc
_prob::_Tprob
_D0tol::_TD0tol
Expand Down Expand Up @@ -135,25 +171,6 @@ function (cf::RSSCostFunction)(sol::Union{Solution, ScaledSolution})
end
end

function (cf::RSSCostFunction)(params::AbstractVector, ::NullParameters = NullParameters())
cf(candidate(cf, params))
end

_solve(cf::RSSCostFunction, params::AbstractVector) = _solve(cf, cf._func(params))
_solve(::RSSCostFunction, prob::AbstractProblem) = solve(prob, verbose = false)
_solve(::RSSCostFunction, sol::Solution) = sol

"""
candidate(cf::RSSCostFunction, ::AbstractVector)
candidate(cf::RSSCostFunction, ::Fronts.AbstractProblem)
candidate(cf::RSSCostFunction, ::Fronts.Solution)
Return the candidate solution for a given cost function and parameter values, problem, or solution.
"""
candidate(cf::RSSCostFunction, params::AbstractVector) = candidate(cf, _solve(cf, params))
candidate(cf::RSSCostFunction, prob::AbstractProblem) = candidate(cf, _solve(cf, prob))
candidate(::RSSCostFunction{false}, sol::Solution) = sol

function candidate(cf::RSSCostFunction{true}, sol::Solution)
if !successful_retcode(sol)
return ScaledSolution(sol, NaN)
Expand Down Expand Up @@ -184,6 +201,6 @@ function candidate(cf::RSSCostFunction{true}, sol::Solution)
return ScaledSolution(sol, scaling.param[1])
end

export RSSCostFunction, candidate
export AbstractCostFunction, RSSCostFunction, candidate

end

0 comments on commit 09402f8

Please sign in to comment.