From 3cd6fcfce90c99a752f69e7b1e0c5897be4eac06 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Fri, 20 Dec 2024 17:12:35 +0100 Subject: [PATCH] Add generated, performant snake_case for NodeID. --- core/src/parameter.jl | 12 ++++++++++++ core/src/util.jl | 6 +++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 072005bf7..6002cc666 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -21,6 +21,17 @@ const SolverStats = @NamedTuple{ 5 Drainage = 6 Precipitation = 7 Base.to_index(id::Substance.T) = Int(id) # used to index into concentration matrices +@generated function config.snake_case(nt::NodeType.T) + ex = quote end + for (sym, _) in EnumX.symbol_map(NodeType.T) + sc = QuoteNode(config.snake_case(sym)) + t = NodeType.T(sym) + push!(ex.args, :(nt === $t && return $sc)) + end + push!(ex.args, :(return :nothing)) # type stability + ex +end + # Support creating a NodeType enum instance from a symbol or string function NodeType.T(s::Symbol)::NodeType.T symbol_map = EnumX.symbol_map(NodeType.T) @@ -86,6 +97,7 @@ Base.convert(::Type{Int32}, id::NodeID) = id.value Base.broadcastable(id::NodeID) = Ref(id) Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.type == id_2.type && id_1.value == id_2.value Base.show(io::IO, id::NodeID) = print(io, id.type, " #", id.value) +config.snake_case(id::NodeID) = config.snake_case(id.type) function Base.isless(id_1::NodeID, id_2::NodeID)::Bool if id_1.type != id_2.type diff --git a/core/src/util.jl b/core/src/util.jl index e9c2a2bcc..2ae1dbeff 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -657,7 +657,7 @@ function get_variable_ref( PreallocationRef(cache(1), flow_idx; from_du = true) end else - node = getfield(p, snake_case(Symbol(node_id.type))) + node = getfield(p, snake_case(node_id)) PreallocationRef(node.flow_rate, node_id.idx) end else @@ -814,7 +814,7 @@ function collect_control_mappings!(p)::Nothing for node_type in instances(NodeType.T) node_type == NodeType.Terminal && continue - node = getfield(p, Symbol(snake_case(string(node_type)))) + node = getfield(p, snake_case(node_type)) if hasfield(typeof(node), :control_mapping) control_mappings[node_type] = node.control_mapping end @@ -1096,7 +1096,7 @@ function get_state_index( component_name = if id.type == NodeType.UserDemand inflow ? :user_demand_inflow : :user_demand_outflow else - snake_case(Symbol(id.type)) + snake_case(id) end for (comp, range) in pairs(NT) if comp == component_name