Skip to content

Commit

Permalink
Merge pull request #118 from PALEOtoolkit/parameter_aggregator
Browse files Browse the repository at this point in the history
Add ParameterAggregator (for parameter sensitivity studies)
  • Loading branch information
sjdaines authored Apr 2, 2024
2 parents 4a44758 + 774a72c commit 8017390
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 17 deletions.
5 changes: 5 additions & 0 deletions docs/src/Solver API.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ copyto!(dest::AbstractVector, src::VariableAggregator; dof::Int=1)
VariableAggregatorNamed
```

Aggregated collections of a subset of Parameters as a flattened Vector (eg for sensitivity studies) is provided by [`ParameterAggregator`](@ref):
```@docs
ParameterAggregator
```

## Defining CellRanges
```@meta
CurrentModule = PALEOboxes
Expand Down
81 changes: 68 additions & 13 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Infiltrator
A biogeochemical model consisting of [`Domain`](@ref)s, created from a [YAML](https://en.wikipedia.org/wiki/YAML)
configuration file using [`create_model_from_config`](@ref).
"""
Base.@kwdef mutable struct Model
Base.@kwdef mutable struct Model <: AbstractModel
name::String
config_files::Vector{String}
parameters::Dict{String, Any}
Expand Down Expand Up @@ -713,6 +713,7 @@ end

"""
do_deriv(dispatchlists, deltat::Float64=0.0)
do_deriv(dispatchlists, pa::ParameterAggregator, deltat::Float64=0.0)
Wrapper function to calculate entire derivative (initialize and do methods) in one call.
`dispatchlists` is from [`create_dispatch_methodlists`](@ref).
Expand All @@ -726,20 +727,61 @@ function do_deriv(dispatchlists, deltat::Float64=0.0)
return nothing
end

function do_deriv(dispatchlists, pa::ParameterAggregator, deltat::Float64=0.0)

dispatch_methodlist(dispatchlists.list_initialize) # assume initialize methods don't use parameters

dispatch_methodlist(dispatchlists.list_do, pa, deltat)

return nothing
end

"""
dispatch_methodlist(dl::ReactionMethodDispatchList, deltat::Float64=0.0)
dispatch_methodlist(dl::ReactionMethodDispatchList, pa::ParameterAggregator, deltat::Float64=0.0)
dispatch_methodlist(dl::ReactionMethodDispatchListNoGen, deltat::Float64=0.0)
dispatch_methodlist(dl::ReactionMethodDispatchListNoGen, pa::ParameterAggregator, deltat::Float64=0.0)
Dispatch to a list of methods.
Call a list of ReactionMethods.
# Implementation
As an optimisation, uses @generated for Type stability
As an optimisation, with `dl::ReactionMethodDispatchList` uses @generated for Type stability
and to avoid dynamic dispatch, instead of iterating over lists.
[`ReactionMethodDispatchList`](@ref) fields are Tuples hence are fully Typed, the @generated
function emits unrolled code with a function call for each Tuple element.
"""
function dispatch_methodlist(
dl::ReactionMethodDispatchListNoGen,
deltat::Float64=0.0
)

for i in eachindex(dl.methods)
call_method(dl.methods[i], dl.vardatas[i], dl.cellranges[i], deltat)
end

return nothing
end

function dispatch_methodlist(
dl::ReactionMethodDispatchListNoGen,
pa::ParameterAggregator,
deltat::Float64=0.0
)

for j in eachindex(dl.methods)
methodref = dl.methods[j]
if has_modified_parameters(pa, methodref)
call_method(methodref, get_parameters(pa, methodref), dl.vardatas[j], dl.cellranges[j], deltat)
else
call_method(methodref, dl.vardatas[j], dl.cellranges[j], deltat)
end
end

return nothing
end

@generated function dispatch_methodlist(
dl::ReactionMethodDispatchList{M, V, C},
deltat::Float64=0.0
Expand All @@ -764,27 +806,40 @@ function emits unrolled code with a function call for each Tuple element.
return ex
end

function dispatch_methodlist(
dl::ReactionMethodDispatchListNoGen,
@generated function dispatch_methodlist(
dl::ReactionMethodDispatchList{M, V, C},
pa::ParameterAggregator,
deltat::Float64=0.0
)

for i in eachindex(dl.methods)
call_method(dl.methods[i], dl.vardatas[i], dl.cellranges[i], deltat)
end
) where {M, V, C}

return nothing
# See https://discourse.julialang.org/t/manually-unroll-operations-with-objects-of-tuple/11604

ex = quote ; end # empty expression
for j=1:fieldcount(M)
push!(ex.args,
quote
if has_modified_parameters(pa, dl.methods[$j])
call_method(dl.methods[$j], get_parameters(pa, dl.methods[$j]), dl.vardatas[$j], dl.cellranges[$j], deltat)
else
call_method(dl.methods[$j], dl.vardatas[$j], dl.cellranges[$j], deltat)
end
end
)
end
push!(ex.args, quote; return nothing; end)

return ex
end

#################################
# Pretty printing
################################

"compact form"
# compact form
function Base.show(io::IO, model::Model)
print(io, "Model(config_files='", model.config_files,"', name='", model.name,"')")
end
"multiline form"
# multiline form
function Base.show(io::IO, ::MIME"text/plain", model::Model)
println(io, "Model")
println(io, "\tname='", model.name,"'")
Expand Down
6 changes: 5 additions & 1 deletion src/PALEOboxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import DataFrames
using DocStringExtensions
import OrderedCollections
import Logging
import Printf

import PrecompileTools
import TimerOutputs: @timeit, @timeit_debug
Expand All @@ -44,11 +45,14 @@ include("data/IsotopeData.jl")
include("VariableAttributes.jl")
include("VariableReaction.jl")
include("VariableDomain.jl")

include("Parameter.jl")
include("ParameterAggregator.jl")

include("ReactionMethodSorting.jl")
include("Model.jl")
include("Domain.jl")
include("CellRange.jl")
include("Parameter.jl")
include("ReactionMethod.jl")
include("Reaction.jl")
include("ReactionFactory.jl")
Expand Down
201 changes: 201 additions & 0 deletions src/ParameterAggregator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
############################################################################################
# ParameterAggregator
###########################################################################################

"""
ParameterAggregator(parfullnames::Vector{String}, model; eltype=Float64) -> ParameterAggregator
Represent a subset of model parameters given by `parfullnames` as a flattened Vector
`parfulnames` is a Vector of form `["domainname.reactionname.parname", ...]` defining a subset of
model parameters (NB: must be of type `ParDouble` or `ParDoubleVec` ie scalar or vector of Float64).
`norm_values` can be used to specify normalisation of the flattened parameter vector (defaults to 1.0).
The parameters can then be set from and copied to a flattened Vector using:
copyto!(pa::ParameterAggregator, newvalues::Vector) # set from newvalues .* norm_values
copyto!(currentvalues::Vector, pa::ParameterAggregator) # copy to currentvalues, dividing by norm_values
get_currentvalues(pa::ParameterAggregator) -> currentvalues::Vector
The subset of parameters are then defined by the `p` parameter Vector used by SciML solvers, and
combined with the full set (from the yaml file) to eg solve an ODE to enable sensitivity studies.
`eltype` can be eg a Dual number to support ForwardDiff automatic differentiation for parameter Jacobians.
"""
mutable struct ParameterAggregator{T}

# Parameter full names
parfullnames::Vector{String}

# replacement Parameters to be used (order matches parfullnames)
replacement_parameters::Vector{Union{Parameter{T, Nothing}, VecParameter{T, Nothing}}}

# indices in flattened p Vector for each replacement parameter
indices::Vector{UnitRange{Int64}}

# normalization (as flattened vector) for parameters
norm_values::Vector{Float64}

# reactions with replacement parameter values (each entry is a Dict of :par_name => index in replacement_parameters)
reactpars::Dict{AbstractReaction, Dict{Symbol, Int}}

# replacement ParametersTuple (merging replacement parameters with all reaction parameters),
# for those reactions that need parameter replacement
reactpartuples::Dict{AbstractReaction, NamedTuple}
end

# compact form
function Base.show(io::IO, pa::ParameterAggregator)
print(io, "ParameterAggregator(parfullnames='", pa.parfullnames,"', indices='", pa.indices,"')")
end
# multiline form
function Base.show(io::IO, ::MIME"text/plain", pa::ParameterAggregator)
println(io, typeof(pa))
Printf.@printf(io, "%40s%20s\n", "parfullname", "indices")
for (pfn, i) in IteratorUtils.zipstrict(pa.parfullnames, pa.indices)
Printf.@printf(io, "%40s%20s\n", pfn, string(i))
end
end




function ParameterAggregator(model::AbstractModel, parfullnames::Vector{String}; eltype=Float64)

reactpars = Dict{AbstractReaction, Dict{Symbol, Int}}()
replacement_parameters = Vector{Union{Parameter{eltype, Nothing}, VecParameter{eltype, Nothing}}}()
indices = UnitRange{Int}[]
nextidx = 1

# iterate through parfullnames and assemble lists of replacement parameters and corresponding indices in flattened vector
for (pidx, domreactpar) in enumerate(parfullnames)
domainname, reactionname, parname = split(domreactpar, ".")

react = get_reaction(model, domainname, reactionname; allow_not_found=false)
p = get_parameter(react, parname)

if p isa Parameter{Float64, Nothing}
replace_p = Parameter{eltype, Nothing}(
p.name, p.description, p.units, eltype(p.v), eltype(p.default_value), eltype[], false, p.external
)
elseif p isa VecParameter{Float64, Nothing}
replace_p = VecParameter{eltype, Nothing}(
p.name, p.description, p.units, Vector{eltype}(p.v), Vector{eltype}(p.default_value), eltype[], false, p.external
)
else
error("parameter $domreactpar $p is not a ParDouble or ParDoubleVec")
end
rparsindices = get!(reactpars, react, Dict{Symbol, Int}())
rparsindices[Symbol(parname)] = pidx

push!(replacement_parameters, replace_p)
endidx = nextidx + length(replace_p) - 1
push!(indices, nextidx:endidx)

nextidx = endidx + 1

end

# generate new ParametersTuple for those reactions that need parameter replacement
reactpartuples = Dict{AbstractReaction, NamedTuple}()
for (react, rparsindices) in reactpars
newparstuple = (haskey(rparsindices, k) ? replacement_parameters[rparsindices[k]] : v for (k, v) in pairs(react.pars))
reactpartuples[react] = NamedTuple{keys(react.pars)}(newparstuple)
end

norm_values = ones(indices[end][end])

return ParameterAggregator{eltype}(
parfullnames,
replacement_parameters,
indices,
norm_values,
reactpars,
reactpartuples,
)
end

Base.copy(pa::ParameterAggregator{old_eltype}) where {old_eltype} = copy_new_eltype(old_eltype, pa)

function copy_new_eltype(new_eltype, pa::ParameterAggregator{old_eltype}) where {old_eltype}

replacement_parameters = Vector{Union{Parameter{new_eltype, Nothing}, VecParameter{new_eltype, Nothing}}}()
for p in pa.replacement_parameters
if p isa Parameter{old_eltype, Nothing}
replace_p = Parameter{new_eltype, Nothing}(
p.name, p.description, p.units, new_eltype(p.v), new_eltype(p.default_value), new_eltype[], false, p.external
)
elseif p isa VecParameter{old_eltype, Nothing}
replace_p = VecParameter{new_eltype, Nothing}(
p.name, p.description, p.units, Vector{new_eltype}(p.v), Vector{new_eltype}(p.default_value), new_eltype[], false, p.external
)
else
error("parameter $p is not a scalar or vector parameter with eltype $old_eltype")
end
push!(replacement_parameters, replace_p)
end

# generate new ParametersTuple for those reactions that need parameter replacement
reactpartuples = Dict{AbstractReaction, NamedTuple}()
for (react, rparsindices) in pa.reactpars
newparstuple = (haskey(rparsindices, k) ? replacement_parameters[rparsindices[k]] : v for (k, v) in pairs(react.pars))
# @Infiltrator.infiltrate
reactpartuples[react] = NamedTuple{keys(react.pars)}(newparstuple)
end

pa_net = ParameterAggregator{new_eltype}(
pa.parfullnames,
replacement_parameters,
pa.indices,
pa.norm_values,
pa.reactpars,
reactpartuples,
)

return pa_net
end

# for use by solver: test whether `reaction` has modified parameters
has_modified_parameters(pa::ParameterAggregator, reaction::AbstractReaction) = haskey(pa.reactpartuples, reaction)

# for use by solver: retrieve modified parameters for `reaction`, or return `nothing` if no modified parameters
get_parameters(pa::ParameterAggregator, reaction::AbstractReaction) = get(pa.reactpartuples, reaction, nothing)

function Base.copyto!(pa::ParameterAggregator, newvalues::Vector)

lastidx = pa.indices[end][end]
lastidx == length(newvalues) ||
error("ParameterAggregator length $lastidx != length(newvalues) $(length(newvalues))")

for (p, indices) in IteratorUtils.zipstrict(pa.replacement_parameters, pa.indices)
if p isa Parameter
# p.v = only(view(newvalues, indices))
setvalue!(p, only(view(newvalues, indices)) * only(view(pa.norm_values, indices)))
elseif p isa VecParameter
p.v .= view(newvalues, indices) .* view(pa.norm_values, indices)
else
error("invalid Parameter type $p")
end
end

return pa
end

function Base.copyto!(currentvalues::Vector, pa::ParameterAggregator)

lastidx = pa.indices[end][end]
lastidx == length(currentvalues) ||
error("ParameterAggregator length $lastidx != length(currentvalues) $(length(currentvalues))")

for (p, indices) in IteratorUtils.zipstrict(pa.replacement_parameters, pa.indices)
currentvalues[indices] .= p.v ./ view(pa.norm_values, indices)
end

return currentvalues
end

function get_currentvalues(pa::ParameterAggregator{T}) where T
currentvalues = Vector{T}(undef, pa.indices[end][end])
return Base.copyto!(currentvalues, pa)
end
2 changes: 1 addition & 1 deletion src/ReactionFactory.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Printf

import InteractiveUtils

"""
Expand Down
Loading

0 comments on commit 8017390

Please sign in to comment.