Skip to content

Commit

Permalink
Add Basin / subgrid_time table (#1975)
Browse files Browse the repository at this point in the history
Fixes #1010

This adds `Basin / subgrid_time`. So far the only relation we could
update over time was `Q(h)` (`TabulatedRatingCurve / time`), and that is
implemented differently. I wrote #1976 to get those more in line. It's
good to read that since it explains the implementation here.

Things I dislike:
- Need a special case to allow an underscore in `Basin / subgrid_time`,
just like `Basin / concentration_`.
- Other dynamic tables are just named time and have a static counterpart
like `Basin / static`, `Basin / time`. Since we already have `Basin /
subgrid` we cannot do that.

---------

Co-authored-by: Maarten Pronk <[email protected]>
  • Loading branch information
visr and evetion authored Dec 20, 2024
1 parent ac0920d commit aecf57f
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 67 deletions.
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using PreallocationTools: LazyBufferCache
# basin profiles and TabulatedRatingCurve. See also the node
# references in the docs.
using DataInterpolations:
ConstantInterpolation,
LinearInterpolation,
LinearInterpolationIntInv,
invert_integral,
Expand Down
22 changes: 19 additions & 3 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,28 @@ function apply_parameter_update!(parameter_update)::Nothing
end

"""
Update the levels of all subgrid elements from the current basin levels.

Static subgrid elements apply their single h(h) relation directly; dynamic
subgrid elements first look up which h(h) relation is active at the current
simulation time `t`, then apply that one.
"""
function update_subgrid_level!(integrator)::Nothing
    (; p, t) = integrator
    du = get_du(integrator)
    basin_level = p.basin.current_properties.current_level[parent(du)]
    subgrid = integrator.p.subgrid

    # First update all the subgrids with static h(h) relations
    for (level_index, basin_index, hh_itp) in zip(
        subgrid.level_index_static,
        subgrid.basin_index_static,
        subgrid.interpolations_static,
    )
        subgrid.level[level_index] = hh_itp(basin_level[basin_index])
    end
    # Then update the subgrids with dynamic h(h) relations:
    # look up the interpolation that is active at time t before evaluating it.
    for (level_index, basin_index, lookup) in zip(
        subgrid.level_index_time,
        subgrid.basin_index_time,
        subgrid.current_interpolation_index,
    )
        itp_index = lookup(t)
        hh_itp = subgrid.interpolations_time[itp_index]
        subgrid.level[level_index] = hh_itp(basin_level[basin_index])
    end
    return nothing
end

Expand Down
31 changes: 28 additions & 3 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ end

Base.to_index(id::NodeID) = Int(id.value)

"LinearInterpolation from a Float64 to a Float64"
const ScalarInterpolation = LinearInterpolation{
Vector{Float64},
Vector{Float64},
Expand All @@ -105,6 +106,10 @@ const ScalarInterpolation = LinearInterpolation{
(1,),
}

"""
ConstantInterpolation from a Float64 (time in seconds) to an Int64 index,
used to look up which interpolation is active at a given time,
e.g. in `Subgrid.current_interpolation_index`.
"""
const IndexLookup =
    ConstantInterpolation{Vector{Int64}, Vector{Float64}, Vector{Float64}, Int64, (1,)}

set_zero!(v) = v .= zero(eltype(v))
const Cache = LazyBufferCache{Returns{Int}, typeof(set_zero!)}

Expand Down Expand Up @@ -867,10 +872,30 @@ end

"Subgrid linearly interpolates basin levels."
@kwdef struct Subgrid
    # current level of each subgrid (static and dynamic) ordered by subgrid_id
    level::Vector{Float64}

    # Static part
    # Static subgrid ids
    subgrid_id_static::Vector{Int32}
    # index into the basin.current_level vector for each static subgrid_id
    basin_index_static::Vector{Int}
    # index into the subgrid.level vector for each static subgrid_id
    level_index_static::Vector{Int}
    # per subgrid one relation
    interpolations_static::Vector{ScalarInterpolation}

    # Dynamic part
    # Dynamic subgrid ids
    subgrid_id_time::Vector{Int32}
    # index into the basin.current_level vector for each dynamic subgrid_id
    basin_index_time::Vector{Int}
    # index into the subgrid.level vector for each dynamic subgrid_id
    level_index_time::Vector{Int}
    # per subgrid n relations, n being the number of timesteps for that subgrid
    interpolations_time::Vector{ScalarInterpolation}
    # per subgrid 1 lookup from t to an index in interpolations_time
    current_interpolation_index::Vector{IndexLookup}
end

"""
Expand Down
175 changes: 146 additions & 29 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,22 @@ function parse_static_and_time(
return out, !errors
end

"""
Retrieve and validate the split of node IDs between static and time tables.
For node types that can have a part of the parameters defined statically and a part dynamically,
this checks if each ID is defined exactly once in either table.
The `is_complete` argument allows disabling the check that all Node IDs of type `node_type`
are either in the `static` or `time` table.
This is not required for Subgrid since not all Basins need to have subgrids.
"""
function static_and_time_node_ids(
db::DB,
static::StructVector,
time::StructVector,
node_type::NodeType.T,
node_type::NodeType.T;
is_complete::Bool = true,
)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool}
node_ids = get_node_ids(db, node_type)
ids = Int32.(node_ids)
Expand All @@ -205,7 +216,7 @@ function static_and_time_node_ids(
errors = true
@error "$node_type cannot be in both static and time tables, found these node IDs in both: $doubles."
end
if !issetequal(node_ids, union(static_node_ids, time_node_ids))
if is_complete && !issetequal(node_ids, union(static_node_ids, time_node_ids))
errors = true
@error "$node_type node IDs don't match."
end
Expand Down Expand Up @@ -1227,46 +1238,152 @@ function FlowDemand(db::DB, config::Config)::FlowDemand
)
end

"""
    push_lookup!(current_interpolation_index, lookup_index, lookup_time)

Construct a constant interpolation mapping a time to the index of the
h(h) relation that is active at that time, extrapolating outside the
given times, and append it to `current_interpolation_index`.
"""
function push_lookup!(
    current_interpolation_index::Vector{IndexLookup},
    lookup_index::Vector{Int},
    lookup_time::Vector{Float64},
)
    return push!(
        current_interpolation_index,
        ConstantInterpolation(
            lookup_index,
            lookup_time;
            extrapolate = true,
            cache_parameters = true,
        ),
    )
end

"""
Construct the `Subgrid` parameter struct from the `Basin / subgrid` (static)
and `Basin / subgrid_time` (dynamic) tables.

Static subgrid elements get a single h(h) relation; dynamic subgrid elements
get one h(h) relation per timestep plus a lookup from time to the index of
the active relation in `interpolations_time`.
"""
function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
    time = load_structvector(db, config, BasinSubgridTimeV1)
    static = load_structvector(db, config, BasinSubgridV1)

    # Since not all Basins need to have subgrids, don't enforce completeness.
    _, _, _, valid =
        static_and_time_node_ids(db, static, time, NodeType.Basin; is_complete = false)
    if !valid
        error("Problems encountered when parsing Subgrid static and time node IDs.")
    end

    node_to_basin = Dict{Int32, Int}(
        Int32(node_id) => index for (index, node_id) in enumerate(basin.node_id)
    )
    subgrid_id_static = Int32[]
    basin_index_static = Int[]
    interpolations_static = ScalarInterpolation[]

    # In the static table, each subgrid ID has 1 h(h) relation. We process one relation
    # at a time and push the results to the respective vectors.
    for group in IterTools.groupby(row -> row.subgrid_id, static)
        subgrid_id = first(getproperty.(group, :subgrid_id))
        node_id = first(getproperty.(group, :node_id))
        basin_level = getproperty.(group, :basin_level)
        subgrid_level = getproperty.(group, :subgrid_level)

        is_valid =
            valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)
        !is_valid && error("Invalid Basin / subgrid table.")

        # Ensure it doesn't extrapolate before the first value.
        pushfirst!(subgrid_level, first(subgrid_level))
        pushfirst!(basin_level, nextfloat(-Inf))
        hh_itp = LinearInterpolation(
            subgrid_level,
            basin_level;
            extrapolate = true,
            cache_parameters = true,
        )
        push!(subgrid_id_static, subgrid_id)
        push!(basin_index_static, node_to_basin[node_id])
        push!(interpolations_static, hh_itp)
    end

    subgrid_id_time = Int32[]
    basin_index_time = Int[]
    interpolations_time = ScalarInterpolation[]
    current_interpolation_index = IndexLookup[]

    # Push the first subgrid_id and basin_index
    if length(time) > 0
        push!(subgrid_id_time, first(time.subgrid_id))
        push!(basin_index_time, node_to_basin[first(time.node_id)])
    end

    # Initialize index_lookup contents
    lookup_time = Float64[]
    lookup_index = Int[]

    interpolation_index = 0
    # In the time table, each subgrid ID can have a different number of relations over time.
    # We group over the combination of subgrid ID and time such that this group has 1 h(h) relation.
    # We process one relation at a time and push the results to the respective vectors.
    # Some vectors are pushed only when the subgrid_id has changed. This can be done in
    # sequence since it is first sorted by subgrid_id and then by time.
    for group in IterTools.groupby(row -> (row.subgrid_id, row.time), time)
        interpolation_index += 1
        subgrid_id = first(getproperty.(group, :subgrid_id))
        time_group = seconds_since(first(getproperty.(group, :time)), config.starttime)
        node_id = first(getproperty.(group, :node_id))
        basin_level = getproperty.(group, :basin_level)
        subgrid_level = getproperty.(group, :subgrid_level)

        is_valid =
            valid_subgrid(subgrid_id, node_id, node_to_basin, basin_level, subgrid_level)
        !is_valid && error("Invalid Basin / subgrid_time table.")

        # Ensure it doesn't extrapolate before the first value.
        pushfirst!(subgrid_level, first(subgrid_level))
        pushfirst!(basin_level, nextfloat(-Inf))
        hh_itp = LinearInterpolation(
            subgrid_level,
            basin_level;
            extrapolate = true,
            cache_parameters = true,
        )
        # These should only be pushed when the subgrid_id has changed
        if subgrid_id_time[end] != subgrid_id
            # Push the completed index_lookup of the previous subgrid_id
            push_lookup!(current_interpolation_index, lookup_index, lookup_time)
            # Push the new subgrid_id and basin_index
            push!(subgrid_id_time, subgrid_id)
            push!(basin_index_time, node_to_basin[node_id])
            # Start new index_lookup contents
            lookup_time = Float64[]
            lookup_index = Int[]
        end
        push!(lookup_index, interpolation_index)
        push!(lookup_time, time_group)
        push!(interpolations_time, hh_itp)
    end

    # Push completed IndexLookup of the last group
    if interpolation_index > 0
        push_lookup!(current_interpolation_index, lookup_index, lookup_time)
    end

    # One level slot per subgrid element (static and dynamic), ordered by subgrid_id.
    level = fill(NaN, length(subgrid_id_static) + length(subgrid_id_time))

    # Find the level indices: positions of each subgrid_id in the sorted union
    # of static and dynamic ids, matching the ordering of `level`.
    level_index_static = zeros(Int, length(subgrid_id_static))
    level_index_time = zeros(Int, length(subgrid_id_time))
    subgrid_ids = sort(vcat(subgrid_id_static, subgrid_id_time))
    for (i, subgrid_id) in enumerate(subgrid_id_static)
        level_index_static[i] = findsorted(subgrid_ids, subgrid_id)
    end
    for (i, subgrid_id) in enumerate(subgrid_id_time)
        level_index_time[i] = findsorted(subgrid_ids, subgrid_id)
    end

    return Subgrid(;
        level,
        subgrid_id_static,
        basin_index_static,
        level_index_static,
        interpolations_static,
        subgrid_id_time,
        basin_index_time,
        level_index_time,
        interpolations_time,
        current_interpolation_index,
    )
end

function Allocation(db::DB, config::Config, graph::MetaGraph)::Allocation
Expand Down
12 changes: 12 additions & 0 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
@schema "ribasim.basin.profile" BasinProfile
@schema "ribasim.basin.state" BasinState
@schema "ribasim.basin.subgrid" BasinSubgrid
@schema "ribasim.basin.subgridtime" BasinSubgridTime
@schema "ribasim.basin.concentration" BasinConcentration
@schema "ribasim.basin.concentrationexternal" BasinConcentrationExternal
@schema "ribasim.basin.concentrationstate" BasinConcentrationState
Expand Down Expand Up @@ -58,8 +59,11 @@ function nodetype(
type_string = string(T)
elements = split(type_string, '.'; limit = 3)
last_element = last(elements)
# Special case last elements that need an underscore
if startswith(last_element, "concentration") && length(last_element) > 13
elements[end] = "concentration_$(last_element[14:end])"
elseif last_element == "subgridtime"
elements[end] = "subgrid_time"
end
if isnode(sv)
n = elements[2]
Expand Down Expand Up @@ -150,6 +154,14 @@ end
subgrid_level::Float64
end

# Schema for the Basin / subgrid_time table: time-varying relations between
# a basin's water level and a subgrid element's level.
@version BasinSubgridTimeV1 begin
    subgrid_id::Int32
    node_id::Int32
    time::DateTime
    basin_level::Float64
    subgrid_level::Float64
end

@version LevelBoundaryStaticV1 begin
node_id::Int32
active::Union{Missing, Bool}
Expand Down
4 changes: 2 additions & 2 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ Validate the entries for a single subgrid element.
"""
function valid_subgrid(
subgrid_id::Int32,
node_id::NodeID,
node_to_basin::Dict{NodeID, Int},
node_id::Int32,
node_to_basin::Dict{Int32, Int},
basin_level::Vector{Float64},
subgrid_level::Vector{Float64},
)::Bool
Expand Down
11 changes: 7 additions & 4 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,14 @@ function subgrid_level_table(
(; t, saveval) = saved.subgrid_level
subgrid = integrator.p.subgrid

nelem = length(subgrid.subgrid_id)
nelem = length(subgrid.level)
ntsteps = length(t)

time = repeat(datetime_since.(t, config.starttime); inner = nelem)
subgrid_id = repeat(subgrid.subgrid_id; outer = ntsteps)
subgrid_id = repeat(
sort(vcat(subgrid.subgrid_id_static, subgrid.subgrid_id_time));
outer = ntsteps,
)
subgrid_level = FlatVector(saveval)
return (; time, subgrid_id, subgrid_level)
end
Expand Down Expand Up @@ -412,9 +415,9 @@ function write_arrow(
mkpath(dirname(path))
try
Arrow.write(path, table; compress, metadata)
catch
catch e
@error "Failed to write results, file may be locked." path
error("Failed to write results.")
rethrow(e)
end
return nothing
end
Expand Down
Loading

0 comments on commit aecf57f

Please sign in to comment.