Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up initialization #1977

Merged
merged 5 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ using StructArrays: StructVector

# OrderedSet is used to store the order of the substances in the network.
# OrderedDict is used to store the order of the sources in a subnetwork.
using DataStructures: OrderedSet, OrderedDict
using DataStructures: OrderedSet, OrderedDict, counter, inc!

export libribasim

Expand Down
7 changes: 4 additions & 3 deletions core/src/graph.jl
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is responsible for the speedup, going from n database queries to 1.

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and data of edges (EdgeMetadata):
[`EdgeMetadata`](@ref)
"""
function create_graph(db::DB, config::Config)::MetaGraph
node_table = get_node_ids(db)
node_rows = execute(
db,
"SELECT node_id, node_type, subnetwork_id FROM Node ORDER BY node_type, node_id",
Expand Down Expand Up @@ -40,7 +41,7 @@ function create_graph(db::DB, config::Config)::MetaGraph
graph_data = nothing,
)
for row in node_rows
node_id = NodeID(row.node_type, row.node_id, db)
node_id = NodeID(row.node_type, row.node_id, node_table)
# Process allocation network ID
if ismissing(row.subnetwork_id)
subnetwork_id = 0
Expand All @@ -63,8 +64,8 @@ function create_graph(db::DB, config::Config)::MetaGraph
catch
error("Invalid edge type $edge_type.")
end
id_src = NodeID(from_node_type, from_node_id, db)
id_dst = NodeID(to_node_type, to_node_id, db)
id_src = NodeID(from_node_type, from_node_id, node_table)
id_dst = NodeID(to_node_type, to_node_id, node_table)
edge_metadata =
EdgeMetadata(; id = edge_id, type = edge_type, edge = (id_src, id_dst))
if edge_type == EdgeType.flow
Expand Down
56 changes: 24 additions & 32 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,17 @@ function NodeType.T(s::Symbol)::NodeType.T
end

NodeType.T(str::AbstractString) = NodeType.T(Symbol(str))
NodeType.T(x::NodeType.T) = x
Base.convert(::Type{NodeType.T}, x::String) = NodeType.T(x)
Base.convert(::Type{NodeType.T}, x::Symbol) = NodeType.T(x)

SQLite.esc_id(x::NodeType.T) = esc_id(string(x))

"""
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, idx::Int)
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, db::DB)
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, p::Parameters)
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, node_ids::Vector{NodeID})

NodeID is a unique identifier for a node in the model, as well as an index into the internal node type struct.

Expand All @@ -52,42 +58,28 @@ This index can be passed directly, or calculated from the database or parameters
idx::Int
end

NodeID(type::Symbol, value::Integer, idx::Int) = NodeID(NodeType.T(type), value, idx)
NodeID(type::AbstractString, value::Integer, idx::Int) =
NodeID(NodeType.T(type), value, idx)

function NodeID(type::Union{Symbol, AbstractString}, value::Integer, db::DB)::NodeID
return NodeID(NodeType.T(type), value, db)
end

function NodeID(type::NodeType.T, value::Integer, db::DB)::NodeID
node_type_string = string(type)
# The index is equal to the number of nodes of the same type with a lower or equal ID
idx = only(
only(
execute(
columntable,
db,
"SELECT COUNT(*) FROM Node WHERE node_type == $(esc_id(node_type_string)) AND node_id <= $value",
),
),
)
if idx <= 0
error("Node ID #$value of type $type is not in the Node table.")
function NodeID(node_type, value::Integer, node_ids::Vector{NodeID})::NodeID
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Often you know the node_type, and since node_ids has it as well we can check if it is correct. Especially in test code I find NodeID(:Pump, 5, v) easier to read than NodeID(5, v). Both is possible however, we call the latter for e.g. listen_node_type handling where we don't know the type.

node_type = NodeType.T(node_type)
index = searchsortedfirst(node_ids, value; by = Int32)
if index == lastindex(node_ids) + 1
@error "Node ID $node_type #$value is not in the Node table."
error("Node ID not found")
end
node_id = node_ids[index]
if node_id.type !== node_type
@error "Requested node ID #$value is of type $(node_id.type), not $node_type"
error("Node ID is of the wrong type")
end
return NodeID(type, value, idx)
return node_id
end

function NodeID(value::Integer, db::DB)::NodeID
(idx, type) = execute(
columntable,
db,
"SELECT COUNT(*), node_type FROM Node WHERE node_type == (SELECT node_type FROM Node WHERE node_id == $value) AND node_id <= $value",
)
if only(idx) <= 0
error("Node ID #$value is not in the Node table.")
function NodeID(value::Integer, node_ids::Vector{NodeID})::NodeID
index = searchsortedfirst(node_ids, value; by = Int32)
if index == lastindex(node_ids) + 1
@error "Node ID #$value is not in the Node table."
error("Node ID not found")
end
return NodeID(only(type), value, only(idx))
return node_ids[index]
end

Base.Int32(id::NodeID) = id.value
Expand Down
99 changes: 71 additions & 28 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ function parse_static_and_time(
# of the current type
vals_out = []

node_type_string = split(string(node_type), '.')[end]
ids = get_ids(db, node_type_string)
node_ids = NodeID.(node_type_string, ids, eachindex(ids))
node_type_string = String(split(string(node_type), '.')[end])
node_ids = get_node_ids(db, node_type_string)
ids = Int32.(node_ids)
n_nodes = length(node_ids)

# Initialize the vectors for the output
Expand Down Expand Up @@ -191,14 +191,14 @@ function static_and_time_node_ids(
db::DB,
static::StructVector,
time::StructVector,
node_type::String,
node_type::NodeType.T,
)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool}
ids = get_ids(db, node_type)
node_ids = get_node_ids(db, node_type)
ids = Int32.(node_ids)
idx = searchsortedfirst.(Ref(ids), static.node_id)
static_node_ids = Set(NodeID.(Ref(node_type), static.node_id, idx))
idx = searchsortedfirst.(Ref(ids), time.node_id)
time_node_ids = Set(NodeID.(Ref(node_type), time.node_id, idx))
node_ids = NodeID.(Ref(node_type), ids, eachindex(ids))
doubles = intersect(static_node_ids, time_node_ids)
errors = false
if !isempty(doubles)
Expand Down Expand Up @@ -287,7 +287,7 @@ function TabulatedRatingCurve(
time = load_structvector(db, config, TabulatedRatingCurveTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_and_time_node_ids(db, static, time, "TabulatedRatingCurve")
static_and_time_node_ids(db, static, time, NodeType.TabulatedRatingCurve)

if !valid
error(
Expand Down Expand Up @@ -418,7 +418,8 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary
time = load_structvector(db, config, LevelBoundaryTimeV1)
concentration_time = load_structvector(db, config, LevelBoundaryConcentrationV1)

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "LevelBoundary")
_, _, node_ids, valid =
static_and_time_node_ids(db, static, time, NodeType.LevelBoundary)

if !valid
error("Problems encountered when parsing LevelBoundary static and time node IDs.")
Expand Down Expand Up @@ -452,7 +453,8 @@ function FlowBoundary(db::DB, config::Config, graph::MetaGraph)::FlowBoundary
time = load_structvector(db, config, FlowBoundaryTimeV1)
concentration_time = load_structvector(db, config, FlowBoundaryConcentrationV1)

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "FlowBoundary")
_, _, node_ids, valid =
static_and_time_node_ids(db, static, time, NodeType.FlowBoundary)

if !valid
error("Problems encountered when parsing FlowBoundary static and time node IDs.")
Expand Down Expand Up @@ -567,8 +569,8 @@ function Outlet(db::DB, config::Config, graph::MetaGraph)::Outlet
end

function Terminal(db::DB, config::Config)::Terminal
node_id = get_ids(db, "Terminal")
return Terminal(NodeID.(NodeType.Terminal, node_id, eachindex(node_id)))
node_id = get_node_ids(db, NodeType.Terminal)
return Terminal(node_id)
end

function ConcentrationData(
Expand Down Expand Up @@ -662,7 +664,7 @@ function ConcentrationData(
end

function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
node_id = get_ids(db, "Basin")
node_id = get_node_ids(db, NodeType.Basin)
n = length(node_id)

# both static and time are optional, but we need fallback defaults
Expand All @@ -683,9 +685,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin

vertical_flux = ComponentVector(; table...)

# Node IDs
node_id = NodeID.(NodeType.Basin, node_id, eachindex(node_id))

# Profiles
area, level = create_storage_tables(db, config)

Expand Down Expand Up @@ -742,9 +741,10 @@ function CompoundVariable(
weight::Float64,
look_ahead::Float64,
}[]
node_ids = get_node_ids(db)
# Each row defines a subvariable
for row in compound_variable_data
listen_node_id = NodeID(row.listen_node_id, db)
listen_node_id = NodeID(row.listen_node_id, node_ids)
# Placeholder until actual ref is known
variable_ref = PreallocationRef(placeholder_vector, 0)
variable = row.variable
Expand All @@ -757,7 +757,7 @@ function CompoundVariable(
end

# The ID of the node listening to this CompoundVariable
node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), db)
node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), node_ids)
return CompoundVariable(node_id, subvariables, greater_than)
end

Expand Down Expand Up @@ -811,8 +811,8 @@ function DiscreteControl(db::DB, config::Config, graph::MetaGraph)::DiscreteCont
condition = load_structvector(db, config, DiscreteControlConditionV1)
compound_variable = load_structvector(db, config, DiscreteControlVariableV1)

ids = get_ids(db, "DiscreteControl")
node_id = NodeID.(:DiscreteControl, ids, eachindex(ids))
node_id = get_node_ids(db, NodeType.DiscreteControl)
ids = Int32.(node_id)
compound_variables, valid =
parse_variables_and_conditions(compound_variable, condition, ids, db, graph)

Expand Down Expand Up @@ -913,8 +913,8 @@ end
function ContinuousControl(db::DB, config::Config, graph::MetaGraph)::ContinuousControl
compound_variable = load_structvector(db, config, ContinuousControlVariableV1)

ids = get_ids(db, "ContinuousControl")
node_id = NodeID.(:ContinuousControl, ids, eachindex(ids))
node_id = get_node_ids(db, NodeType.ContinuousControl)
ids = Int32.(node_id)

# Avoid using `function` as a variable name as that is recognized as a keyword
func, controlled_variable, errors = continuous_control_functions(db, config, ids)
Expand All @@ -940,7 +940,7 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl
static = load_structvector(db, config, PidControlStaticV1)
time = load_structvector(db, config, PidControlTimeV1)

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "PidControl")
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.PidControl)

if !valid
error("Problems encountered when parsing PidControl static and time node IDs.")
Expand Down Expand Up @@ -968,7 +968,8 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl
end
controlled_basins = collect(controlled_basins)

listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(db))
all_node_ids = get_node_ids(db)
listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(all_node_ids))

return PidControl(;
node_id = node_ids,
Expand Down Expand Up @@ -1087,9 +1088,9 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand
static = load_structvector(db, config, UserDemandStaticV1)
time = load_structvector(db, config, UserDemandTimeV1)
concentration_time = load_structvector(db, config, UserDemandConcentrationV1)
ids = get_ids(db, "UserDemand")

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "UserDemand")
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.UserDemand)
ids = Int32.(node_ids)

if !valid
error("Problems encountered when parsing UserDemand static and time node IDs.")
Expand Down Expand Up @@ -1229,14 +1230,15 @@ end
function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id))
tables = load_structvector(db, config, BasinSubgridV1)
node_table = get_node_ids(db, NodeType.Basin)

subgrid_ids = Int32[]
basin_index = Int32[]
interpolations = ScalarInterpolation[]
has_error = false
for group in IterTools.groupby(row -> row.subgrid_id, tables)
subgrid_id = first(getproperty.(group, :subgrid_id))
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), db)
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), node_table)
basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

Expand Down Expand Up @@ -1395,11 +1397,52 @@ function Parameters(db::DB, config::Config)::Parameters
return p
end

function get_ids(db::DB, nodetype)::Vector{Int32}
sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(nodetype)) ORDER BY node_id"
function get_node_ints(db::DB, node_type)::Vector{Int32}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This signature is very similar to get_node_ids and I almost thought you made a typo. get_node_ids_int32 or similar?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha yeah good point, done.

sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(node_type)) ORDER BY node_id"
return only(execute(columntable, db, sql))
end

function get_node_ids_types(
db::DB,
)::@NamedTuple{node_id::Vector{Int32}, node_type::Vector{NodeType.T}}
sql = "SELECT node_id, node_type FROM Node ORDER BY node_id"
table = execute(columntable, db, sql)
# convert from String to NodeType
node_type = NodeType.T.(table.node_type)
return (; table.node_id, node_type)
end

function get_node_ids(db::DB)::Vector{NodeID}
nt = get_node_ids_types(db)
node_ids = Vector{Ribasim.NodeID}(undef, length(nt.node_id))
count = counter(Ribasim.NodeType.T)
for (i, (; node_id, node_type)) in enumerate(Tables.rows(nt))
index = inc!(count, node_type)
node_ids[i] = NodeID(node_type, node_id, index)
end
return node_ids
end

# Convenience method for tests
function get_node_ids(toml_path::String)::Vector{NodeID}
cfg = Config(toml_path)
db_path = database_path(cfg)
db = SQLite.DB(db_path)
node_ids = get_node_ids(db)
close(db)
return node_ids
end

function get_node_ids(db::DB, node_type)::Vector{NodeID}
node_type = NodeType.T(node_type)
node_ints = get_node_ints(db, node_type)
node_ids = Vector{Ribasim.NodeID}(undef, length(node_ints))
for (index, node_int) in enumerate(node_ints)
node_ids[index] = NodeID(node_type, node_int, index)
end
return node_ids
end

function exists(db::DB, tablename::String)
query = execute(
db,
Expand Down
8 changes: 4 additions & 4 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,11 @@ Data is matched based on the node_id, which is sorted.
"""
function set_static_value!(
table::NamedTuple,
node_id::Vector{Int32},
node_id::Vector{NodeID},
static::StructVector,
)::NamedTuple
for (i, id) in enumerate(node_id)
idx = findsorted(static.node_id, id)
idx = findsorted(static.node_id, Int32(id))
idx === nothing && continue
row = static[idx]
set_table_row!(table, row, i)
Expand All @@ -199,7 +199,7 @@ The most recent applicable data is non-NaN data for a given ID that is on or bef
"""
function set_current_value!(
table::NamedTuple,
node_id::Vector{Int32},
node_id::Vector{NodeID},
time::StructVector,
t::DateTime,
)::NamedTuple
Expand All @@ -209,7 +209,7 @@ function set_current_value!(
for (i, id) in enumerate(node_id)
for (symbol, vector) in pairs(table)
idx = findlast(
row -> row.node_id == id && !ismissing(getproperty(row, symbol)),
row -> row.node_id == Int32(id) && !ismissing(getproperty(row, symbol)),
pre_table,
)
if idx !== nothing
Expand Down
Loading
Loading