diff --git a/core/src/Ribasim.jl b/core/src/Ribasim.jl index 6beebf38c..1c02e67b5 100644 --- a/core/src/Ribasim.jl +++ b/core/src/Ribasim.jl @@ -148,7 +148,7 @@ using StructArrays: StructVector # OrderedSet is used to store the order of the substances in the network. # OrderedDict is used to store the order of the sources in a subnetwork. -using DataStructures: OrderedSet, OrderedDict +using DataStructures: OrderedSet, OrderedDict, counter, inc! export libribasim diff --git a/core/src/graph.jl b/core/src/graph.jl index e2a3273e0..f3ea0780b 100644 --- a/core/src/graph.jl +++ b/core/src/graph.jl @@ -6,6 +6,7 @@ and data of edges (EdgeMetadata): [`EdgeMetadata`](@ref) """ function create_graph(db::DB, config::Config)::MetaGraph + node_table = get_node_ids(db) node_rows = execute( db, "SELECT node_id, node_type, subnetwork_id FROM Node ORDER BY node_type, node_id", @@ -40,7 +41,7 @@ function create_graph(db::DB, config::Config)::MetaGraph graph_data = nothing, ) for row in node_rows - node_id = NodeID(row.node_type, row.node_id, db) + node_id = NodeID(row.node_type, row.node_id, node_table) # Process allocation network ID if ismissing(row.subnetwork_id) subnetwork_id = 0 @@ -63,8 +64,8 @@ function create_graph(db::DB, config::Config)::MetaGraph catch error("Invalid edge type $edge_type.") end - id_src = NodeID(from_node_type, from_node_id, db) - id_dst = NodeID(to_node_type, to_node_id, db) + id_src = NodeID(from_node_type, from_node_id, node_table) + id_dst = NodeID(to_node_type, to_node_id, node_table) edge_metadata = EdgeMetadata(; id = edge_id, type = edge_type, edge = (id_src, id_dst)) if edge_type == EdgeType.flow diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 03b909799..072005bf7 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -31,11 +31,16 @@ function NodeType.T(s::Symbol)::NodeType.T end NodeType.T(str::AbstractString) = NodeType.T(Symbol(str)) +NodeType.T(x::NodeType.T) = x +Base.convert(::Type{NodeType.T}, x::String) = NodeType.T(x) +Base.convert(::Type{NodeType.T}, x::Symbol) = NodeType.T(x) + +SQLite.esc_id(x::NodeType.T) = esc_id(string(x)) """ NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, idx::Int) - NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, db::DB) NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, p::Parameters) + NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, node_ids::Vector{NodeID}) NodeID is a unique identifier for a node in the model, as well as an index into the internal node type struct. @@ -52,42 +57,28 @@ This index can be passed directly, or calculated from the database or parameters idx::Int end -NodeID(type::Symbol, value::Integer, idx::Int) = NodeID(NodeType.T(type), value, idx) -NodeID(type::AbstractString, value::Integer, idx::Int) = - NodeID(NodeType.T(type), value, idx) - -function NodeID(type::Union{Symbol, AbstractString}, value::Integer, db::DB)::NodeID - return NodeID(NodeType.T(type), value, db) -end - -function NodeID(type::NodeType.T, value::Integer, db::DB)::NodeID - node_type_string = string(type) - # The index is equal to the number of nodes of the same type with a lower or equal ID - idx = only( - only( - execute( - columntable, - db, - "SELECT COUNT(*) FROM Node WHERE node_type == $(esc_id(node_type_string)) AND node_id <= $value", - ), - ), - ) - if idx <= 0 - error("Node ID #$value of type $type is not in the Node table.") +function NodeID(node_type, value::Integer, node_ids::Vector{NodeID})::NodeID + node_type = NodeType.T(node_type) + index = searchsortedfirst(node_ids, value; by = Int32) + if index == lastindex(node_ids) + 1 + @error "Node ID $node_type #$value is not in the Node table." + error("Node ID not found") + end + node_id = node_ids[index] + if node_id.type !== node_type + @error "Requested node ID #$value is of type $(node_id.type), not $node_type" + error("Node ID is of the wrong type") end - return NodeID(type, value, idx) + return node_id end -function NodeID(value::Integer, db::DB)::NodeID - (idx, type) = execute( - columntable, - db, - "SELECT COUNT(*), node_type FROM Node WHERE node_type == (SELECT node_type FROM Node WHERE node_id == $value) AND node_id <= $value", - ) - if only(idx) <= 0 - error("Node ID #$value is not in the Node table.") +function NodeID(value::Integer, node_ids::Vector{NodeID})::NodeID + index = searchsortedfirst(node_ids, value; by = Int32) + if index == lastindex(node_ids) + 1 + @error "Node ID #$value is not in the Node table." + error("Node ID not found") end - return NodeID(only(type), value, only(idx)) + return node_ids[index] end Base.Int32(id::NodeID) = id.value diff --git a/core/src/read.jl b/core/src/read.jl index e99d207e0..255043371 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -32,9 +32,9 @@ function parse_static_and_time( # of the current type vals_out = [] - node_type_string = split(string(node_type), '.')[end] - ids = get_ids(db, node_type_string) - node_ids = NodeID.(node_type_string, ids, eachindex(ids)) + node_type_string = String(split(string(node_type), '.')[end]) + node_ids = get_node_ids(db, node_type_string) + ids = Int32.(node_ids) n_nodes = length(node_ids) # Initialize the vectors for the output @@ -191,14 +191,14 @@ function static_and_time_node_ids( db::DB, static::StructVector, time::StructVector, - node_type::String, + node_type::NodeType.T, )::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool} - ids = get_ids(db, node_type) + node_ids = get_node_ids(db, node_type) + ids = Int32.(node_ids) idx = searchsortedfirst.(Ref(ids), static.node_id) static_node_ids = Set(NodeID.(Ref(node_type), static.node_id, idx)) idx = searchsortedfirst.(Ref(ids), time.node_id) time_node_ids = Set(NodeID.(Ref(node_type), time.node_id, idx)) - node_ids = NodeID.(Ref(node_type), ids, eachindex(ids)) doubles = intersect(static_node_ids, time_node_ids) errors = false if !isempty(doubles) @@ -287,7 +287,7 @@ function TabulatedRatingCurve( time = load_structvector(db, config, TabulatedRatingCurveTimeV1) static_node_ids, time_node_ids, node_ids, valid = - static_and_time_node_ids(db, static, time, "TabulatedRatingCurve") + static_and_time_node_ids(db, static, time, NodeType.TabulatedRatingCurve) if !valid error( @@ -418,7 +418,8 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary time = load_structvector(db, config, LevelBoundaryTimeV1) concentration_time = load_structvector(db, config, LevelBoundaryConcentrationV1) - _, _, node_ids, valid = static_and_time_node_ids(db, static, time, "LevelBoundary") + _, _, node_ids, valid = + static_and_time_node_ids(db, static, time, NodeType.LevelBoundary) if !valid error("Problems encountered when parsing LevelBoundary static and time node IDs.") @@ -452,7 +453,8 @@ function FlowBoundary(db::DB, config::Config, graph::MetaGraph)::FlowBoundary time = load_structvector(db, config, FlowBoundaryTimeV1) concentration_time = load_structvector(db, config, FlowBoundaryConcentrationV1) - _, _, node_ids, valid = static_and_time_node_ids(db, static, time, "FlowBoundary") + _, _, node_ids, valid = + static_and_time_node_ids(db, static, time, NodeType.FlowBoundary) if !valid error("Problems encountered when parsing FlowBoundary static and time node IDs.") @@ -567,8 +569,8 @@ function Outlet(db::DB, config::Config, graph::MetaGraph)::Outlet end function Terminal(db::DB, config::Config)::Terminal - node_id = get_ids(db, "Terminal") - return Terminal(NodeID.(NodeType.Terminal, node_id, eachindex(node_id))) + node_id = get_node_ids(db, NodeType.Terminal) + return Terminal(node_id) end function ConcentrationData( @@ -662,7 +664,7 @@ function ConcentrationData( end function Basin(db::DB, config::Config, graph::MetaGraph)::Basin - node_id = get_ids(db, "Basin") + node_id = get_node_ids(db, NodeType.Basin) n = length(node_id) # both static and time are optional, but we need fallback defaults @@ -683,9 +685,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin vertical_flux = ComponentVector(; table...) - # Node IDs - node_id = NodeID.(NodeType.Basin, node_id, eachindex(node_id)) - # Profiles area, level = create_storage_tables(db, config) @@ -742,9 +741,10 @@ function CompoundVariable( weight::Float64, look_ahead::Float64, }[] + node_ids = get_node_ids(db) # Each row defines a subvariable for row in compound_variable_data - listen_node_id = NodeID(row.listen_node_id, db) + listen_node_id = NodeID(row.listen_node_id, node_ids) # Placeholder until actual ref is known variable_ref = PreallocationRef(placeholder_vector, 0) variable = row.variable @@ -757,7 +757,7 @@ function CompoundVariable( end # The ID of the node listening to this CompoundVariable - node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), db) + node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), node_ids) return CompoundVariable(node_id, subvariables, greater_than) end @@ -811,8 +811,8 @@ function DiscreteControl(db::DB, config::Config, graph::MetaGraph)::DiscreteCont condition = load_structvector(db, config, DiscreteControlConditionV1) compound_variable = load_structvector(db, config, DiscreteControlVariableV1) - ids = get_ids(db, "DiscreteControl") - node_id = NodeID.(:DiscreteControl, ids, eachindex(ids)) + node_id = get_node_ids(db, NodeType.DiscreteControl) + ids = Int32.(node_id) compound_variables, valid = parse_variables_and_conditions(compound_variable, condition, ids, db, graph) @@ -913,8 +913,8 @@ end function ContinuousControl(db::DB, config::Config, graph::MetaGraph)::ContinuousControl compound_variable = load_structvector(db, config, ContinuousControlVariableV1) - ids = get_ids(db, "ContinuousControl") - node_id = NodeID.(:ContinuousControl, ids, eachindex(ids)) + node_id = get_node_ids(db, NodeType.ContinuousControl) + ids = Int32.(node_id) # Avoid using `function` as a variable name as that is recognized as a keyword func, controlled_variable, errors = continuous_control_functions(db, config, ids) @@ -940,7 +940,7 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl static = load_structvector(db, config, PidControlStaticV1) time = load_structvector(db, config, PidControlTimeV1) - _, _, node_ids, valid = static_and_time_node_ids(db, static, time, "PidControl") + _, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.PidControl) if !valid error("Problems encountered when parsing PidControl static and time node IDs.") @@ -968,7 +968,8 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl end controlled_basins = collect(controlled_basins) - listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(db)) + all_node_ids = get_node_ids(db) + listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(all_node_ids)) return PidControl(; node_id = node_ids, @@ -1087,9 +1088,9 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand static = load_structvector(db, config, UserDemandStaticV1) time = load_structvector(db, config, UserDemandTimeV1) concentration_time = load_structvector(db, config, UserDemandConcentrationV1) - ids = get_ids(db, "UserDemand") - _, _, node_ids, valid = static_and_time_node_ids(db, static, time, "UserDemand") + _, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.UserDemand) + ids = Int32.(node_ids) if !valid error("Problems encountered when parsing UserDemand static and time node IDs.") @@ -1229,6 +1230,7 @@ end function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id)) tables = load_structvector(db, config, BasinSubgridV1) + node_table = get_node_ids(db, NodeType.Basin) subgrid_ids = Int32[] basin_index = Int32[] @@ -1236,7 +1238,7 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid has_error = false for group in IterTools.groupby(row -> row.subgrid_id, tables) subgrid_id = first(getproperty.(group, :subgrid_id)) - node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), db) + node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), node_table) basin_level = getproperty.(group, :basin_level) subgrid_level = getproperty.(group, :subgrid_level) @@ -1395,11 +1397,52 @@ function Parameters(db::DB, config::Config)::Parameters return p end -function get_ids(db::DB, nodetype)::Vector{Int32} - sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(nodetype)) ORDER BY node_id" +function get_node_ids_int32(db::DB, node_type)::Vector{Int32} + sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(node_type)) ORDER BY node_id" return only(execute(columntable, db, sql)) end +function get_node_ids_types( + db::DB, +)::@NamedTuple{node_id::Vector{Int32}, node_type::Vector{NodeType.T}} + sql = "SELECT node_id, node_type FROM Node ORDER BY node_id" + table = execute(columntable, db, sql) + # convert from String to NodeType + node_type = NodeType.T.(table.node_type) + return (; table.node_id, node_type) +end + +function get_node_ids(db::DB)::Vector{NodeID} + nt = get_node_ids_types(db) + node_ids = Vector{Ribasim.NodeID}(undef, length(nt.node_id)) + count = counter(Ribasim.NodeType.T) + for (i, (; node_id, node_type)) in enumerate(Tables.rows(nt)) + index = inc!(count, node_type) + node_ids[i] = NodeID(node_type, node_id, index) + end + return node_ids +end + +# Convenience method for tests +function get_node_ids(toml_path::String)::Vector{NodeID} + cfg = Config(toml_path) + db_path = database_path(cfg) + db = SQLite.DB(db_path) + node_ids = get_node_ids(db) + close(db) + return node_ids +end + +function get_node_ids(db::DB, node_type)::Vector{NodeID} + node_type = NodeType.T(node_type) + node_ints = get_node_ids_int32(db, node_type) + node_ids = Vector{Ribasim.NodeID}(undef, length(node_ints)) + for (index, node_int) in enumerate(node_ints) + node_ids[index] = NodeID(node_type, node_int, index) + end + return node_ids +end + function exists(db::DB, tablename::String) query = execute( db, diff --git a/core/src/util.jl b/core/src/util.jl index 2eac3f267..e9c2a2bcc 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -180,11 +180,11 @@ Data is matched based on the node_id, which is sorted. """ function set_static_value!( table::NamedTuple, - node_id::Vector{Int32}, + node_id::Vector{NodeID}, static::StructVector, )::NamedTuple for (i, id) in enumerate(node_id) - idx = findsorted(static.node_id, id) + idx = findsorted(static.node_id, Int32(id)) idx === nothing && continue row = static[idx] set_table_row!(table, row, i) @@ -199,7 +199,7 @@ The most recent applicable data is non-NaN data for a given ID that is on or bef """ function set_current_value!( table::NamedTuple, - node_id::Vector{Int32}, + node_id::Vector{NodeID}, time::StructVector, t::DateTime, )::NamedTuple @@ -209,7 +209,7 @@ function set_current_value!( for (i, id) in enumerate(node_id) for (symbol, vector) in pairs(table) idx = findlast( - row -> row.node_id == id && !ismissing(getproperty(row, symbol)), + row -> row.node_id == Int32(id) && !ismissing(getproperty(row, symbol)), pre_table, ) if idx !== nothing diff --git a/core/test/allocation_test.jl b/core/test/allocation_test.jl index 3dd6c7ba0..05cb76ce3 100644 --- a/core/test/allocation_test.jl +++ b/core/test/allocation_test.jl @@ -6,11 +6,8 @@ toml_path = normpath(@__DIR__, "../../generated_testmodels/subnetwork/ribasim.toml") @test ispath(toml_path) - cfg = Ribasim.Config(toml_path) - db_path = Ribasim.database_path(cfg) - db = SQLite.DB(db_path) - p = Ribasim.Parameters(db, cfg) - close(db) + model = Ribasim.Model(toml_path) + p = model.integrator.p (; graph, allocation) = p @@ -47,8 +44,7 @@ end normpath(@__DIR__, "../../generated_testmodels/minimal_subnetwork/ribasim.toml") @test ispath(toml_path) - config = Ribasim.Config(toml_path) - model = Ribasim.run(config) + model = Ribasim.run(toml_path) @test successful_retcode(model) (; u, p, t) = model.integrator (; user_demand) = p @@ -80,11 +76,8 @@ end "../../generated_testmodels/main_network_with_subnetworks/ribasim.toml", ) @test ispath(toml_path) - cfg = Ribasim.Config(toml_path) - db_path = Ribasim.database_path(cfg) - db = SQLite.DB(db_path) - p = Ribasim.Parameters(db, cfg) - close(db) + model = Ribasim.Model(toml_path) + p = model.integrator.p (; allocation, graph) = p (; main_network_connections, subnetwork_ids, allocation_models) = allocation @test Ribasim.has_main_network(allocation) @@ -214,11 +207,8 @@ end "../../generated_testmodels/subnetworks_with_sources/ribasim.toml", ) @test ispath(toml_path) - cfg = Ribasim.Config(toml_path) - db_path = Ribasim.database_path(cfg) - db = SQLite.DB(db_path) - p = Ribasim.Parameters(db, cfg) - close(db) + model = Ribasim.Model(toml_path) + p = model.integrator.p (; allocation, user_demand, graph, basin) = p (; allocation_models, subnetwork_demands, subnetwork_allocateds, mean_input_flows) = diff --git a/core/test/validation_test.jl b/core/test/validation_test.jl index e46c0ef74..bf596539c 100644 --- a/core/test/validation_test.jl +++ b/core/test/validation_test.jl @@ -471,20 +471,14 @@ end toml_path = normpath(@__DIR__, "../../generated_testmodels/basic/ribasim.toml") - cfg = Ribasim.Config(toml_path) - db_path = Ribasim.database_path(cfg) - db = SQLite.DB(db_path) + v = Ribasim.get_node_ids(toml_path) logger = TestLogger() with_logger(logger) do - @test_throws "Node ID #1 of type PidControl is not in the Node table." Ribasim.NodeID( - :PidControl, - 1, - db, - ) + @test_throws "Node ID is of the wrong type" Ribasim.NodeID(:PidControl, 1, v) end with_logger(logger) do - @test_throws "Node ID #20 is not in the Node table." Ribasim.NodeID(20, db) + @test_throws "Node ID not found" Ribasim.NodeID(:Pump, 20, v) end end