-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed up initialization #1977
Speed up initialization #1977
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,11 +31,17 @@ function NodeType.T(s::Symbol)::NodeType.T | |
end | ||
|
||
NodeType.T(str::AbstractString) = NodeType.T(Symbol(str)) | ||
NodeType.T(x::NodeType.T) = x | ||
Base.convert(::Type{NodeType.T}, x::String) = NodeType.T(x) | ||
Base.convert(::Type{NodeType.T}, x::Symbol) = NodeType.T(x) | ||
|
||
SQLite.esc_id(x::NodeType.T) = esc_id(string(x)) | ||
|
||
""" | ||
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, idx::Int) | ||
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, db::DB) | ||
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, p::Parameters) | ||
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, node_ids::Vector{NodeID}) | ||
|
||
NodeID is a unique identifier for a node in the model, as well as an index into the internal node type struct. | ||
|
||
|
@@ -52,42 +58,28 @@ This index can be passed directly, or calculated from the database or parameters | |
idx::Int | ||
end | ||
|
||
NodeID(type::Symbol, value::Integer, idx::Int) = NodeID(NodeType.T(type), value, idx) | ||
NodeID(type::AbstractString, value::Integer, idx::Int) = | ||
NodeID(NodeType.T(type), value, idx) | ||
|
||
function NodeID(type::Union{Symbol, AbstractString}, value::Integer, db::DB)::NodeID | ||
return NodeID(NodeType.T(type), value, db) | ||
end | ||
|
||
function NodeID(type::NodeType.T, value::Integer, db::DB)::NodeID | ||
node_type_string = string(type) | ||
# The index is equal to the number of nodes of the same type with a lower or equal ID | ||
idx = only( | ||
only( | ||
execute( | ||
columntable, | ||
db, | ||
"SELECT COUNT(*) FROM Node WHERE node_type == $(esc_id(node_type_string)) AND node_id <= $value", | ||
), | ||
), | ||
) | ||
if idx <= 0 | ||
error("Node ID #$value of type $type is not in the Node table.") | ||
function NodeID(node_type, value::Integer, node_ids::Vector{NodeID})::NodeID | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Often you know the |
||
node_type = NodeType.T(node_type) | ||
index = searchsortedfirst(node_ids, value; by = Int32) | ||
if index == lastindex(node_ids) + 1 | ||
@error "Node ID $node_type #$value is not in the Node table." | ||
error("Node ID not found") | ||
end | ||
node_id = node_ids[index] | ||
if node_id.type !== node_type | ||
@error "Requested node ID #$value is of type $(node_id.type), not $node_type" | ||
error("Node ID is of the wrong type") | ||
end | ||
return NodeID(type, value, idx) | ||
return node_id | ||
end | ||
|
||
function NodeID(value::Integer, db::DB)::NodeID | ||
(idx, type) = execute( | ||
columntable, | ||
db, | ||
"SELECT COUNT(*), node_type FROM Node WHERE node_type == (SELECT node_type FROM Node WHERE node_id == $value) AND node_id <= $value", | ||
) | ||
if only(idx) <= 0 | ||
error("Node ID #$value is not in the Node table.") | ||
function NodeID(value::Integer, node_ids::Vector{NodeID})::NodeID | ||
index = searchsortedfirst(node_ids, value; by = Int32) | ||
if index == lastindex(node_ids) + 1 | ||
@error "Node ID #$value is not in the Node table." | ||
error("Node ID not found") | ||
end | ||
return NodeID(only(type), value, only(idx)) | ||
return node_ids[index] | ||
end | ||
|
||
Base.Int32(id::NodeID) = id.value | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,9 +32,9 @@ function parse_static_and_time( | |
# of the current type | ||
vals_out = [] | ||
|
||
node_type_string = split(string(node_type), '.')[end] | ||
ids = get_ids(db, node_type_string) | ||
node_ids = NodeID.(node_type_string, ids, eachindex(ids)) | ||
node_type_string = String(split(string(node_type), '.')[end]) | ||
node_ids = get_node_ids(db, node_type_string) | ||
ids = Int32.(node_ids) | ||
n_nodes = length(node_ids) | ||
|
||
# Initialize the vectors for the output | ||
|
@@ -191,14 +191,14 @@ function static_and_time_node_ids( | |
db::DB, | ||
static::StructVector, | ||
time::StructVector, | ||
node_type::String, | ||
node_type::NodeType.T, | ||
)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool} | ||
ids = get_ids(db, node_type) | ||
node_ids = get_node_ids(db, node_type) | ||
ids = Int32.(node_ids) | ||
idx = searchsortedfirst.(Ref(ids), static.node_id) | ||
static_node_ids = Set(NodeID.(Ref(node_type), static.node_id, idx)) | ||
idx = searchsortedfirst.(Ref(ids), time.node_id) | ||
time_node_ids = Set(NodeID.(Ref(node_type), time.node_id, idx)) | ||
node_ids = NodeID.(Ref(node_type), ids, eachindex(ids)) | ||
doubles = intersect(static_node_ids, time_node_ids) | ||
errors = false | ||
if !isempty(doubles) | ||
|
@@ -287,7 +287,7 @@ function TabulatedRatingCurve( | |
time = load_structvector(db, config, TabulatedRatingCurveTimeV1) | ||
|
||
static_node_ids, time_node_ids, node_ids, valid = | ||
static_and_time_node_ids(db, static, time, "TabulatedRatingCurve") | ||
static_and_time_node_ids(db, static, time, NodeType.TabulatedRatingCurve) | ||
|
||
if !valid | ||
error( | ||
|
@@ -418,7 +418,8 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary | |
time = load_structvector(db, config, LevelBoundaryTimeV1) | ||
concentration_time = load_structvector(db, config, LevelBoundaryConcentrationV1) | ||
|
||
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "LevelBoundary") | ||
_, _, node_ids, valid = | ||
static_and_time_node_ids(db, static, time, NodeType.LevelBoundary) | ||
|
||
if !valid | ||
error("Problems encountered when parsing LevelBoundary static and time node IDs.") | ||
|
@@ -452,7 +453,8 @@ function FlowBoundary(db::DB, config::Config, graph::MetaGraph)::FlowBoundary | |
time = load_structvector(db, config, FlowBoundaryTimeV1) | ||
concentration_time = load_structvector(db, config, FlowBoundaryConcentrationV1) | ||
|
||
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "FlowBoundary") | ||
_, _, node_ids, valid = | ||
static_and_time_node_ids(db, static, time, NodeType.FlowBoundary) | ||
|
||
if !valid | ||
error("Problems encountered when parsing FlowBoundary static and time node IDs.") | ||
|
@@ -567,8 +569,8 @@ function Outlet(db::DB, config::Config, graph::MetaGraph)::Outlet | |
end | ||
|
||
function Terminal(db::DB, config::Config)::Terminal | ||
node_id = get_ids(db, "Terminal") | ||
return Terminal(NodeID.(NodeType.Terminal, node_id, eachindex(node_id))) | ||
node_id = get_node_ids(db, NodeType.Terminal) | ||
return Terminal(node_id) | ||
end | ||
|
||
function ConcentrationData( | ||
|
@@ -662,7 +664,7 @@ function ConcentrationData( | |
end | ||
|
||
function Basin(db::DB, config::Config, graph::MetaGraph)::Basin | ||
node_id = get_ids(db, "Basin") | ||
node_id = get_node_ids(db, NodeType.Basin) | ||
n = length(node_id) | ||
|
||
# both static and time are optional, but we need fallback defaults | ||
|
@@ -683,9 +685,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin | |
|
||
vertical_flux = ComponentVector(; table...) | ||
|
||
# Node IDs | ||
node_id = NodeID.(NodeType.Basin, node_id, eachindex(node_id)) | ||
|
||
# Profiles | ||
area, level = create_storage_tables(db, config) | ||
|
||
|
@@ -742,9 +741,10 @@ function CompoundVariable( | |
weight::Float64, | ||
look_ahead::Float64, | ||
}[] | ||
node_ids = get_node_ids(db) | ||
# Each row defines a subvariable | ||
for row in compound_variable_data | ||
listen_node_id = NodeID(row.listen_node_id, db) | ||
listen_node_id = NodeID(row.listen_node_id, node_ids) | ||
# Placeholder until actual ref is known | ||
variable_ref = PreallocationRef(placeholder_vector, 0) | ||
variable = row.variable | ||
|
@@ -757,7 +757,7 @@ function CompoundVariable( | |
end | ||
|
||
# The ID of the node listening to this CompoundVariable | ||
node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), db) | ||
node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), node_ids) | ||
return CompoundVariable(node_id, subvariables, greater_than) | ||
end | ||
|
||
|
@@ -811,8 +811,8 @@ function DiscreteControl(db::DB, config::Config, graph::MetaGraph)::DiscreteCont | |
condition = load_structvector(db, config, DiscreteControlConditionV1) | ||
compound_variable = load_structvector(db, config, DiscreteControlVariableV1) | ||
|
||
ids = get_ids(db, "DiscreteControl") | ||
node_id = NodeID.(:DiscreteControl, ids, eachindex(ids)) | ||
node_id = get_node_ids(db, NodeType.DiscreteControl) | ||
ids = Int32.(node_id) | ||
compound_variables, valid = | ||
parse_variables_and_conditions(compound_variable, condition, ids, db, graph) | ||
|
||
|
@@ -913,8 +913,8 @@ end | |
function ContinuousControl(db::DB, config::Config, graph::MetaGraph)::ContinuousControl | ||
compound_variable = load_structvector(db, config, ContinuousControlVariableV1) | ||
|
||
ids = get_ids(db, "ContinuousControl") | ||
node_id = NodeID.(:ContinuousControl, ids, eachindex(ids)) | ||
node_id = get_node_ids(db, NodeType.ContinuousControl) | ||
ids = Int32.(node_id) | ||
|
||
# Avoid using `function` as a variable name as that is recognized as a keyword | ||
func, controlled_variable, errors = continuous_control_functions(db, config, ids) | ||
|
@@ -940,7 +940,7 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl | |
static = load_structvector(db, config, PidControlStaticV1) | ||
time = load_structvector(db, config, PidControlTimeV1) | ||
|
||
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "PidControl") | ||
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.PidControl) | ||
|
||
if !valid | ||
error("Problems encountered when parsing PidControl static and time node IDs.") | ||
|
@@ -968,7 +968,8 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl | |
end | ||
controlled_basins = collect(controlled_basins) | ||
|
||
listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(db)) | ||
all_node_ids = get_node_ids(db) | ||
listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(all_node_ids)) | ||
|
||
return PidControl(; | ||
node_id = node_ids, | ||
|
@@ -1087,9 +1088,9 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand | |
static = load_structvector(db, config, UserDemandStaticV1) | ||
time = load_structvector(db, config, UserDemandTimeV1) | ||
concentration_time = load_structvector(db, config, UserDemandConcentrationV1) | ||
ids = get_ids(db, "UserDemand") | ||
|
||
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "UserDemand") | ||
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.UserDemand) | ||
ids = Int32.(node_ids) | ||
|
||
if !valid | ||
error("Problems encountered when parsing UserDemand static and time node IDs.") | ||
|
@@ -1229,14 +1230,15 @@ end | |
function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid | ||
node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id)) | ||
tables = load_structvector(db, config, BasinSubgridV1) | ||
node_table = get_node_ids(db, NodeType.Basin) | ||
|
||
subgrid_ids = Int32[] | ||
basin_index = Int32[] | ||
interpolations = ScalarInterpolation[] | ||
has_error = false | ||
for group in IterTools.groupby(row -> row.subgrid_id, tables) | ||
subgrid_id = first(getproperty.(group, :subgrid_id)) | ||
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), db) | ||
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), node_table) | ||
basin_level = getproperty.(group, :basin_level) | ||
subgrid_level = getproperty.(group, :subgrid_level) | ||
|
||
|
@@ -1395,11 +1397,52 @@ function Parameters(db::DB, config::Config)::Parameters | |
return p | ||
end | ||
|
||
function get_ids(db::DB, nodetype)::Vector{Int32} | ||
sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(nodetype)) ORDER BY node_id" | ||
function get_node_ints(db::DB, node_type)::Vector{Int32} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: This signature is very similar to get_node_ids and I almost thought you made a typo. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ha yeah good point, done. |
||
sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(node_type)) ORDER BY node_id" | ||
return only(execute(columntable, db, sql)) | ||
end | ||
|
||
"""
    get_node_ids_types(db::DB)

Read the `node_id` and `node_type` columns of the `Node` table in a single query,
ordered by `node_id`.

Returns a named tuple of two equally long vectors: the node IDs as `Int32` and the
node types converted to `NodeType.T`.
"""
function get_node_ids_types(
    db::DB,
)::@NamedTuple{node_id::Vector{Int32}, node_type::Vector{NodeType.T}}
    query = "SELECT node_id, node_type FROM Node ORDER BY node_id"
    columns = execute(columntable, db, query)
    # The query result holds the node types in raw form; turn each into a NodeType.T
    parsed_types = NodeType.T.(columns.node_type)
    return (node_id = columns.node_id, node_type = parsed_types)
end
|
||
"""
    get_node_ids(db::DB)::Vector{NodeID}

Build a `NodeID` for every row returned by `get_node_ids_types(db)`, in the same order.

The third argument passed to each `NodeID` is the running per-type tally kept by the
counter, so nodes of the same type are numbered in the order they are encountered.
"""
function get_node_ids(db::DB)::Vector{NodeID}
    table = get_node_ids_types(db)
    # Running tally of how many nodes of each type have been encountered
    seen_per_type = counter(Ribasim.NodeType.T)
    result = Vector{Ribasim.NodeID}(undef, length(table.node_id))
    for (position, row) in enumerate(Tables.rows(table))
        type_index = inc!(seen_per_type, row.node_type)
        result[position] = NodeID(row.node_type, row.node_id, type_index)
    end
    return result
end
|
||
# Convenience method for tests | ||
"""
    get_node_ids(toml_path::String)::Vector{NodeID}

Load the configuration at `toml_path`, open the database it points to, and return the
result of `get_node_ids(db)` for that database.

The database handle is closed before returning, also when reading the node IDs throws.
"""
function get_node_ids(toml_path::String)::Vector{NodeID}
    cfg = Config(toml_path)
    db_path = database_path(cfg)
    db = SQLite.DB(db_path)
    try
        return get_node_ids(db)
    finally
        # Release the handle on every exit path so a failed query does not leak it
        close(db)
    end
end
|
||
"""
    get_node_ids(db::DB, node_type)::Vector{NodeID}

Return the `NodeID`s of all nodes of the given `node_type`, in the order returned by
`get_node_ints`.

`node_type` is first converted with `NodeType.T`. Each node's index is its position
in the list returned by `get_node_ints`.
"""
function get_node_ids(db::DB, node_type)::Vector{NodeID}
    node_type = NodeType.T(node_type)
    node_ints = get_node_ints(db, node_type)
    # The position within the per-type list doubles as the NodeID index
    return Ribasim.NodeID[
        NodeID(node_type, node_int, position) for
        (position, node_int) in enumerate(node_ints)
    ]
end
|
||
function exists(db::DB, tablename::String) | ||
query = execute( | ||
db, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is responsible for the speedup, going from n database queries to 1.