Skip to content

Commit

Permalink
Add generated, performant snake_case for NodeID.
Browse files Browse the repository at this point in the history
  • Loading branch information
evetion committed Dec 20, 2024
1 parent ac0920d commit 3cd6fcf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
12 changes: 12 additions & 0 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ const SolverStats = @NamedTuple{
5 Drainage = 6 Precipitation = 7
Base.to_index(id::Substance.T) = Int(id) # used to index into concentration matrices

@generated function config.snake_case(nt::NodeType.T)
ex = quote end
for (sym, _) in EnumX.symbol_map(NodeType.T)
sc = QuoteNode(config.snake_case(sym))
t = NodeType.T(sym)
push!(ex.args, :(nt === $t && return $sc))
end
push!(ex.args, :(return :nothing)) # type stability
ex
end

# Support creating a NodeType enum instance from a symbol or string
function NodeType.T(s::Symbol)::NodeType.T
symbol_map = EnumX.symbol_map(NodeType.T)
Expand Down Expand Up @@ -86,6 +97,7 @@ Base.convert(::Type{Int32}, id::NodeID) = id.value
Base.broadcastable(id::NodeID) = Ref(id)
Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.type == id_2.type && id_1.value == id_2.value
Base.show(io::IO, id::NodeID) = print(io, id.type, " #", id.value)
config.snake_case(id::NodeID) = config.snake_case(id.type)

function Base.isless(id_1::NodeID, id_2::NodeID)::Bool
if id_1.type != id_2.type
Expand Down
6 changes: 3 additions & 3 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ function get_variable_ref(
PreallocationRef(cache(1), flow_idx; from_du = true)
end
else
node = getfield(p, snake_case(Symbol(node_id.type)))
node = getfield(p, snake_case(node_id))
PreallocationRef(node.flow_rate, node_id.idx)
end
else
Expand Down Expand Up @@ -814,7 +814,7 @@ function collect_control_mappings!(p)::Nothing

for node_type in instances(NodeType.T)
node_type == NodeType.Terminal && continue
node = getfield(p, Symbol(snake_case(string(node_type))))
node = getfield(p, snake_case(node_type))
if hasfield(typeof(node), :control_mapping)
control_mappings[node_type] = node.control_mapping
end
Expand Down Expand Up @@ -1096,7 +1096,7 @@ function get_state_index(
component_name = if id.type == NodeType.UserDemand
inflow ? :user_demand_inflow : :user_demand_outflow
else
snake_case(Symbol(id.type))
snake_case(id)
end
for (comp, range) in pairs(NT)
if comp == component_name
Expand Down

0 comments on commit 3cd6fcf

Please sign in to comment.