From 3b7d691635b736f34b1e23f86bbcba63e2accc52 Mon Sep 17 00:00:00 2001
From: jalving <jhjalving@gmail.com>
Date: Tue, 29 Oct 2024 19:30:12 -0700
Subject: [PATCH 1/6] support constrained variables and parameters

---
 src/Plasmo.jl               |   6 +-
 src/backends/moi_backend.jl |  49 +++++++--
 src/core_types.jl           |   2 -
 src/node_variables.jl       | 198 ++++++++++++++++++++++++------------
 src/optiedge.jl             |  10 +-
 src/optigraph.jl            |   6 +-
 src/optimizer_interface.jl  |   2 +-
 src/optinode.jl             |   2 +-
 src/utilities.jl            |  39 +++++++
 test/test_optigraph.jl      |   3 +
 10 files changed, 231 insertions(+), 86 deletions(-)
 create mode 100644 src/utilities.jl

diff --git a/src/Plasmo.jl b/src/Plasmo.jl
index cd875d1..1e192e2 100644
--- a/src/Plasmo.jl
+++ b/src/Plasmo.jl
@@ -29,6 +29,7 @@ export OptiGraph,
     OptiNode,
     OptiEdge,
     NodeVariableRef,
+    EdgeConstraintRef,
     direct_moi_graph,
     graph_backend,
     graph_index,
@@ -108,7 +109,8 @@ export OptiGraph,
 
     # other functions
 
-    set_jump_model
+    set_jump_model,
+    extract_variables
 
 include("core_types.jl")
 
@@ -138,6 +140,8 @@ include("graph_functions/topology.jl")
 
 include("graph_functions/partition.jl")
 
+include("utilities.jl")
+
 # extensions
 function __init__()
     @require KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" include(
diff --git a/src/backends/moi_backend.jl b/src/backends/moi_backend.jl
index f4ca62c..2dd76cb 100644
--- a/src/backends/moi_backend.jl
+++ b/src/backends/moi_backend.jl
@@ -528,22 +528,54 @@ end
 # MOI variables and constraints
 #
 
-function MOI.add_variable(graph_backend::GraphMOIBackend, vref::NodeVariableRef)
+function MOI.add_variable(backend::GraphMOIBackend, vref::NodeVariableRef)
     # return if variable already exists in backend
-    vref in keys(graph_backend.element_to_graph_map.var_map) && return nothing
+    vref in keys(backend.element_to_graph_map.var_map) && return nothing
 
     # add the variable
-    graph_var_index = MOI.add_variable(graph_backend.moi_backend)
+    graph_var_index = MOI.add_variable(backend.moi_backend)
 
     # map reference to index
-    graph_backend.element_to_graph_map[vref] = graph_var_index
-    graph_backend.graph_to_element_map[graph_var_index] = vref
+    backend.element_to_graph_map[vref] = graph_var_index
+    backend.graph_to_element_map[graph_var_index] = vref
 
     # create key for node if necessary
-    if !haskey(graph_backend.node_variables, vref.node)
-        graph_backend.node_variables[vref.node] = MOI.VariableIndex[]
+    if !haskey(backend.node_variables, vref.node)
+        backend.node_variables[vref.node] = MOI.VariableIndex[]
     end
-    push!(graph_backend.node_variables[vref.node], graph_var_index)
+    push!(backend.node_variables[vref.node], graph_var_index)
+    return graph_var_index
+end
+
+function MOI.add_constrained_variable(
+    backend::GraphMOIBackend,
+    vref::NodeVariableRef,
+    cref::NodeConstraintRef,
+    set::MOI.AbstractScalarSet,
+)
+    # return if variable already exists in backend
+    vref in keys(backend.element_to_graph_map.var_map) && return nothing
+
+    # add the constrained variable and its set constraint
+    graph_var_index, graph_con_index = MOI.add_constrained_variable(
+        backend.moi_backend, set
+    )
+
+    # map reference to index
+    backend.element_to_graph_map[vref] = graph_var_index
+    backend.graph_to_element_map[graph_var_index] = vref
+    backend.element_to_graph_map[cref] = graph_con_index
+    backend.graph_to_element_map[graph_con_index] = cref
+
+    # create key for node if necessary
+    if !haskey(backend.node_variables, vref.node)
+        backend.node_variables[vref.node] = MOI.VariableIndex[]
+    end
+    if !haskey(backend.element_constraints, vref.node)
+        backend.element_constraints[vref.node] = MOI.ConstraintIndex[]
+    end
+    push!(backend.node_variables[vref.node], graph_var_index)
+    push!(backend.element_constraints[vref.node], graph_con_index)
     return graph_var_index
 end
 
@@ -879,6 +911,7 @@ function _copy_node_variables(
 
     # map existing variables in the index_map
     # existing variables may come from linking constraints added between graphs
+    # TODO: could be slow...
     existing_vars = intersect(node_variables, keys(dest.element_to_graph_map.var_map))
     for var in existing_vars
         src_graph_index = graph_index(var)
diff --git a/src/core_types.jl b/src/core_types.jl
index 19bd4cf..5398583 100644
--- a/src/core_types.jl
+++ b/src/core_types.jl
@@ -69,8 +69,6 @@ struct ElementData{GT<:AbstractOptiGraph}
     # track constraint indices
     last_constraint_index::OrderedDict{OptiElement,Int}
 end
-
-# default is OptiGraph
 function ElementData(GT::Type{<:AbstractOptiGraph})
     return ElementData{GT}(
         OrderedDict{OptiNode{GT},Vector{GT}}(),
diff --git a/src/node_variables.jl b/src/node_variables.jl
index 4ff88a7..3088522 100644
--- a/src/node_variables.jl
+++ b/src/node_variables.jl
@@ -3,7 +3,7 @@
 #  License, v. 2.0. If a copy of the MPL was not distributed with this
 #  file, You can obtain one at https://mozilla.org/MPL/2.0/.
 
-# TODO: parameterize on precision
+# TODO: parameterize variables on precision
 
 struct NodeVariableRef <: JuMP.AbstractVariableRef
     node::OptiNode
@@ -70,23 +70,20 @@ function MOI.delete(node::OptiNode, vref::NodeVariableRef)
     return nothing
 end
 
+# add variable
+
 """
     JuMP.add_variable(node::OptiNode, v::JuMP.AbstractVariable, name::String="")
 
 Add variable `v` to optinode `node`. This function supports use of the `@variable` JuMP macro.
 Optionally add a `base_name` to the variable for printing.
 """
-function JuMP.add_variable(node::OptiNode, v::JuMP.AbstractVariable, name::String="")
-    vref = _moi_add_node_variable(node, v)
-    if !isempty(name) && MOI.supports(
-        JuMP.backend(graph_backend(node)), MOI.VariableName(), MOI.VariableIndex
-    )
-        JuMP.set_name(vref, "$(JuMP.name(node))[:$(name)]")
-    end
-    return vref
+function JuMP.add_variable(node::OptiNode, v::JuMP.ScalarVariable, name::String="")
+    nvref = _moi_add_node_variable(node, v, name)
+    return nvref
 end
 
-function _moi_add_node_variable(node::OptiNode, v::JuMP.AbstractVariable)
+function _moi_add_node_variable(node::OptiNode, v::JuMP.ScalarVariable, name::String)
     # get a new variable index and create a reference
     variable_index = next_variable_index(node)
     nvref = NodeVariableRef(node, variable_index)
@@ -98,6 +95,11 @@ function _moi_add_node_variable(node::OptiNode, v::JuMP.AbstractVariable)
 
     # constrain node variable (hits all graph backends)
     _moi_constrain_node_variable(nvref, v.info, Float64)
+
+    if !isempty(name) &&
+        MOI.supports(JuMP.backend(node), MOI.VariableName(), MOI.VariableIndex)
+        JuMP.set_name(nvref, "$(JuMP.name(node))[:$(name)]")
+    end
     return nvref
 end
 
@@ -128,6 +130,64 @@ function _moi_constrain_node_variable(nvref::NodeVariableRef, info, ::Type{T}) w
     end
 end
 
+# add variable constrained on creation
+
+function JuMP.add_variable(
+    node::OptiNode, variable::VariableConstrainedOnCreation, name::String
+)
+    nvref = _moi_add_constrained_node_variable(
+        node, variable.scalar_variable, variable.set, name, Float64
+    )
+    return nvref
+end
+
+function JuMP.add_variable(
+    node::OptiNode,
+    variables::AbstractArray{<:VariableConstrainedOnCreation},
+    names::AbstractArray{<:String},
+)
+    return JuMP.add_variable.(node, variables, names)
+end
+
+function JuMP.add_variable(
+    node::OptiNode, variables::AbstractArray{<:VariableConstrainedOnCreation}, name::String
+)
+    return JuMP.add_variable.(node, variables, Ref(name))
+end
+
+function _moi_add_constrained_node_variable(
+    node::OptiNode,
+    scalar_variable::ScalarVariable,
+    set::MOI.AbstractScalarSet,
+    name::String,
+    ::Type{T},
+) where {T}
+    # get a new variable index and create a reference
+    variable_index = next_variable_index(node)
+    nvref = NodeVariableRef(node, variable_index)
+
+    # get a new constraint index and create a reference
+    constraint_index = next_constraint_index(
+        node, MOI.VariableIndex, typeof(set)
+    )::MOI.ConstraintIndex{MOI.VariableIndex,typeof(set)}
+    cref = ConstraintRef(node, constraint_index, JuMP.ScalarShape())
+
+    # add variable to all containing optigraphs
+    for graph in containing_optigraphs(node)
+        MOI.add_constrained_variable(JuMP.backend(graph), nvref, cref, set)
+    end
+
+    _moi_constrain_node_variable(nvref, scalar_variable.info, T)
+
+    if !isempty(name) &&
+        MOI.supports(JuMP.backend(node), MOI.VariableName(), MOI.VariableIndex)
+        JuMP.set_name(nvref, "$(JuMP.name(node))[:$(name)]")
+    end
+    return nvref
+end
+
+# variable methods
+
 function JuMP.delete(node::OptiNode, nvref::NodeVariableRef)
     if node !== JuMP.owner_model(nvref)
         error(
@@ -167,7 +227,7 @@ function JuMP.index(vref::NodeVariableRef)
     return vref.index
 end
 
-### variable values
+# variable primal values
 
 function JuMP.value(nvref::NodeVariableRef; result::Int=1)
     return MOI.get(graph_backend(nvref.node), MOI.VariablePrimal(result), nvref)
@@ -177,7 +237,43 @@ function JuMP.value(var_value::Function, vref::NodeVariableRef)
     return var_value(vref)
 end
 
-### variable start values
+# parameters
+
+function JuMP.ParameterRef(nvref::NodeVariableRef)
+    if !JuMP.is_parameter(nvref)
+        error("Variable $x is not a parameter.")
+    end
+    return ConstraintRef(JuMP.owner_model(nvref), _parameter_index(nvref), ScalarShape())
+end
+
+function JuMP.is_parameter(nvref::NodeVariableRef)
+    return MOI.is_valid(
+        JuMP.backend(JuMP.owner_model(nvref)), _parameter_index(nvref)
+    )::Bool
+end
+
+function JuMP.parameter_value(nvref::NodeVariableRef)
+    set = MOI.get(
+        JuMP.owner_model(nvref), MOI.ConstraintSet(), ParameterRef(nvref)
+    )::MOI.Parameter{JuMP.value_type(typeof(nvref))}
+    return set.value
+end
+
+function JuMP.set_parameter_value(nvref::NodeVariableRef, value)
+    node = JuMP.owner_model(nvref)
+    T = JuMP.value_type(typeof(nvref))
+    _set_dirty(node)
+    set = MOI.Parameter{T}(convert(T, value))
+    MOI.set(node, MOI.ConstraintSet(), ParameterRef(nvref), set)
+    return nothing
+end
+
+function _parameter_index(nvref::NodeVariableRef)
+    F, S = MOI.VariableIndex, MOI.Parameter{JuMP.value_type(typeof(nvref))}
+    return MOI.ConstraintIndex{F,S}(JuMP.index(nvref).value)
+end
+
+# variable start values
 
 function JuMP.start_value(nvref::NodeVariableRef)
     return MOI.get(graph_backend(nvref.node), MOI.VariablePrimalStart(), nvref)
@@ -192,7 +288,14 @@ function JuMP.set_start_value(nvref::NodeVariableRef, value::Union{Nothing,Real}
     )
 end
 
-### node variable bounds
+# variable bounds - lower bound
+
+function JuMP.LowerBoundRef(nvref::NodeVariableRef)
+    if !JuMP.has_lower_bound(nvref)
+        error("Variable $(nvref) does not have a lower bound.")
+    end
+    return _nv_lower_bound_ref(nvref)
+end
 
 function JuMP.has_lower_bound(nvref::NodeVariableRef)
     return _moi_nv_has_lower_bound(nvref)
@@ -220,13 +323,6 @@ function JuMP.delete_lower_bound(nvref::NodeVariableRef)
     return nothing
 end
 
-function JuMP.LowerBoundRef(nvref::NodeVariableRef)
-    if !JuMP.has_lower_bound(nvref)
-        error("Variable $(nvref) does not have a lower bound.")
-    end
-    return _nv_lower_bound_ref(nvref)
-end
-
 function _moi_nv_has_lower_bound(nvref::NodeVariableRef)
     backend = graph_backend(nvref.node)
     ci = MOI.ConstraintIndex{MOI.VariableIndex,MOI.GreaterThan{Float64}}(
@@ -258,6 +354,15 @@ function _moi_nv_set_lower_bound(nvref::NodeVariableRef, lower::Number)
     return nothing
 end
 
+# variable bounds - upper bound
+
+function JuMP.UpperBoundRef(nvref::NodeVariableRef)
+    if !JuMP.has_upper_bound(nvref)
+        error("Variable $(nvref) does not have an upper bound.")
+    end
+    return _nv_upper_bound_ref(nvref)
+end
+
 function JuMP.has_upper_bound(nvref::NodeVariableRef)
     return _moi_nv_has_upper_bound(nvref)
 end
@@ -284,13 +389,6 @@ function JuMP.delete_upper_bound(nvref::NodeVariableRef)
     return nothing
 end
 
-function JuMP.UpperBoundRef(nvref::NodeVariableRef)
-    if !JuMP.has_upper_bound(nvref)
-        error("Variable $(nvref) does not have an upper bound.")
-    end
-    return _nv_upper_bound_ref(nvref)
-end
-
 function _moi_nv_has_upper_bound(nvref::NodeVariableRef)
     backend = graph_backend(nvref.node)
     ci = MOI.ConstraintIndex{MOI.VariableIndex,MOI.LessThan{Float64}}(
@@ -322,7 +420,7 @@ function _moi_nv_set_upper_bound(nvref::NodeVariableRef, upper::Number)
     return nothing
 end
 
-### fix/unfix variable
+# fix/unfix variable
 
 function JuMP.FixRef(nvref::NodeVariableRef)
     if !JuMP.is_fixed(nvref)
@@ -404,7 +502,7 @@ function JuMP.unfix(nvref::NodeVariableRef)
     return nothing
 end
 
-### node variable integer
+# variable integer
 
 function JuMP.IntegerRef(nvref::NodeVariableRef)
     if !JuMP.is_integer(nvref)
@@ -460,7 +558,7 @@ function JuMP.unset_integer(nvref::NodeVariableRef)
     return nothing
 end
 
-### node variable binary
+# variable binary
 
 function JuMP.BinaryRef(nvref::NodeVariableRef)
     if !JuMP.is_binary(nvref)
@@ -516,7 +614,9 @@ function JuMP.unset_binary(nvref::NodeVariableRef)
     return nothing
 end
 
-# Extended from https://github.com/jump-dev/JuMP.jl/blob/301d46e81cb66c74c6e22cd89fb89ced740f157b/src/variables.jl#L2721
+# normalized coefficient
+
+## Extended from https://github.com/jump-dev/JuMP.jl/blob/301d46e81cb66c74c6e22cd89fb89ced740f157b/src/variables.jl#L2721
 function JuMP.set_normalized_coefficient(
     con_ref::S, variable::NodeVariableRef, value::Number
 ) where {S<:Union{NodeConstraintRef,EdgeConstraintRef}}
@@ -647,39 +747,3 @@ function JuMP.set_normalized_coefficient(
     graph.is_model_dirty = true
     return nothing
 end
-
-### Utilities for querying variables used in constraints
-
-function _extract_variables(func::NodeVariableRef)
-    return [func]
-end
-
-function _extract_variables(ref::ConstraintRef)
-    func = JuMP.jump_function(JuMP.constraint_object(ref))
-    return _extract_variables(func)
-end
-
-function _extract_variables(func::JuMP.GenericAffExpr)
-    return collect(keys(func.terms))
-end
-
-function _extract_variables(func::JuMP.GenericQuadExpr)
-    quad_vars = vcat([[term[2]; term[3]] for term in JuMP.quad_terms(func)]...)
-    aff_vars = _extract_variables(func.aff)
-    return union(quad_vars, aff_vars)
-end
-
-function _extract_variables(func::JuMP.GenericNonlinearExpr)
-    vars = NodeVariableRef[]
-    for i in 1:length(func.args)
-        func_arg = func.args[i]
-        if func_arg isa Number
-            continue
-        elseif typeof(func_arg) == NodeVariableRef
-            push!(vars, func_arg)
-        else
-            append!(vars, _extract_variables(func_arg))
-        end
-    end
-    return vars
-end
diff --git a/src/optiedge.jl b/src/optiedge.jl
index 2bed1d9..518875b 100644
--- a/src/optiedge.jl
+++ b/src/optiedge.jl
@@ -11,13 +11,13 @@ Base.show(io::IO, edge::OptiEdge) = Base.print(io, edge)
 
 function Base.setindex!(edge::OptiEdge, value::Any, name::Symbol)
     t = (edge, name)
-    source_graph(edge).edge_obj_dict[t] = value
+    source_graph(edge).element_data.edge_obj_dict[t] = value
     return nothing
 end
 
 function Base.getindex(edge::OptiEdge, name::Symbol)
     t = (edge, name)
-    return edge.source_graph.edge_obj_dict[t]
+    return source_graph(edge).element_data.edge_obj_dict[t]
 end
 
 """
@@ -72,7 +72,7 @@ end
 
 function JuMP.all_variables(edge::OptiEdge)
     con_refs = JuMP.all_constraints(edge)
-    vars = vcat(_extract_variables.(con_refs)...)
+    vars = vcat(extract_variables.(con_refs)...)
     return unique(vars)
 end
 
@@ -134,6 +134,10 @@ function JuMP.is_valid(edge::OptiEdge, cref::ConstraintRef)
     return edge === JuMP.owner_model(cref) && MOI.is_valid(graph_backend(edge), cref)
 end
 
+function get_edge(cref::EdgeConstraintRef)
+    return JuMP.owner_model(cref)
+end
+
 """
     JuMP.dual(cref::EdgeConstraintRef; result::Int=1)
 
diff --git a/src/optigraph.jl b/src/optigraph.jl
index 742f928..4ee979b 100644
--- a/src/optigraph.jl
+++ b/src/optigraph.jl
@@ -199,7 +199,7 @@ Add an existing optinode (created in another optigraph) to `graph`. This copies
 from the other graph to the new graph.
 """
 function add_node(graph::OptiGraph, node::OptiNode)
-    node in all_nodes(graph) && error("Node already exists within graph")
+    # node in all_nodes(graph) && error("Node already exists within graph")
     push!(graph.optinodes, node)
     add_node(graph_backend(graph), node)
     _track_node_in_graph(graph, node)
@@ -232,7 +232,7 @@ end
 Retrieve the optinodes contained in a JuMP expression.
 """
 function collect_nodes(jump_func::T where {T<:JuMP.AbstractJuMPScalar})
-    vars = _extract_variables(jump_func)
+    vars = extract_variables(jump_func)
     nodes = JuMP.owner_model.(vars)
     return collect(nodes)
 end
@@ -319,7 +319,7 @@ Add an existing optiedge (created in another optigraph) to `graph`. This copies
 from the other graph to the new graph.
 """
 function add_edge(graph::OptiGraph, edge::OptiEdge)
-    edge in all_edges(graph) && error("Cannot add the same edge to a graph multiple times")
+    # edge in all_edges(graph) && error("Cannot add the same edge to a graph multiple times")
     push!(graph.optiedges, edge)
     add_edge(graph_backend(graph), edge)
     _track_edge_in_graph(graph, edge)
diff --git a/src/optimizer_interface.jl b/src/optimizer_interface.jl
index 346cc2f..ee8d493 100644
--- a/src/optimizer_interface.jl
+++ b/src/optimizer_interface.jl
@@ -117,7 +117,7 @@ function JuMP.set_optimizer(
 )
     JuMP.error_if_direct_mode(JuMP.backend(graph), :set_optimizer)
     if add_bridges
-        optimizer = MOI.instantiate(optimizer_constructor)#; with_bridge_type = T)
+        optimizer = MOI.instantiate(optimizer_constructor; with_bridge_type=Float64)
         for BT in graph.bridge_types
             _moi_call_bridge_function(MOI.Bridges.add_bridge, optimizer, BT)
         end
diff --git a/src/optinode.jl b/src/optinode.jl
index dde6fa5..11a0e11 100644
--- a/src/optinode.jl
+++ b/src/optinode.jl
@@ -202,7 +202,7 @@ function _check_node_variables(
         NodeVariableRef,JuMP.GenericAffExpr,JuMP.GenericQuadExpr,JuMP.GenericNonlinearExpr
     },
 )
-    extract_vars = _extract_variables(jump_func)
+    extract_vars = extract_variables(jump_func)
     for var in extract_vars
         if var.node != node
             error("Variable $var does not belong to node $node")
diff --git a/src/utilities.jl b/src/utilities.jl
new file mode 100644
index 0000000..5b0cb8a
--- /dev/null
+++ b/src/utilities.jl
@@ -0,0 +1,39 @@
+### Utilities for querying variables used in constraints
+
+function extract_variables(func)
+    return _extract_variables(func)
+end
+
+function _extract_variables(func::NodeVariableRef)
+    return [func]
+end
+
+function _extract_variables(ref::EdgeConstraintRef)
+    func = JuMP.jump_function(JuMP.constraint_object(ref))
+    return _extract_variables(func)
+end
+
+function _extract_variables(func::JuMP.GenericAffExpr)
+    return collect(keys(func.terms))
+end
+
+function _extract_variables(func::JuMP.GenericQuadExpr)
+    quad_vars = vcat([[term[2]; term[3]] for term in JuMP.quad_terms(func)]...)
+    aff_vars = _extract_variables(func.aff)
+    return union(quad_vars, aff_vars)
+end
+
+function _extract_variables(func::JuMP.GenericNonlinearExpr)
+    vars = NodeVariableRef[]
+    for i in 1:length(func.args)
+        func_arg = func.args[i]
+        if func_arg isa Number
+            continue
+        elseif typeof(func_arg) == NodeVariableRef
+            push!(vars, func_arg)
+        else
+            append!(vars, _extract_variables(func_arg))
+        end
+    end
+    return vars
+end
diff --git a/test/test_optigraph.jl b/test/test_optigraph.jl
index 7eb4195..2d499fd 100644
--- a/test/test_optigraph.jl
+++ b/test/test_optigraph.jl
@@ -330,6 +330,9 @@ function test_variable_constraints()
     @variable(n1, x >= 1)
     @variable(n2, 0 <= x <= 2)
 
+    # parameter
+    @variable(n1, p in Parameter(1.0))
+
     # start value
     set_start_value(n2[:x], 3.0)
     @test start_value(n2[:x]) == 3.0
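
A rough usage sketch of what this patch enables (the graph and node names below are illustrative, not taken from the patch): scalar variables constrained on creation, such as a `Parameter`, can now be declared on an optinode with the usual `@variable` syntax; they are routed through the new `MOI.add_constrained_variable` backend method on every optigraph that contains the node.

    using Plasmo, JuMP

    graph = OptiGraph()
    n1 = add_node(graph)

    # a regular variable: added via MOI.add_variable on the graph backend
    @variable(n1, x >= 1)

    # a variable constrained on creation: the MOI.Parameter set is attached when
    # the variable is created, via the new MOI.add_constrained_variable method
    @variable(n1, p in Parameter(1.0))

    # the parameter can appear in node expressions like any other node variable
    @constraint(n1, x >= p)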

From 768a0c6faaaa844b68bb95e7751a69a9dedcaaf3 Mon Sep 17 00:00:00 2001
From: jalving <jhjalving@gmail.com>
Date: Tue, 29 Oct 2024 20:51:36 -0700
Subject: [PATCH 2/6] fix parameter issues

---
 src/backends/moi_backend.jl | 8 +++++---
 src/node_variables.jl       | 7 +++++--
 test/test_optigraph.jl      | 4 ++++
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/backends/moi_backend.jl b/src/backends/moi_backend.jl
index 2dd76cb..260ac3e 100644
--- a/src/backends/moi_backend.jl
+++ b/src/backends/moi_backend.jl
@@ -269,11 +269,13 @@ function _add_edge(backend::GraphMOIBackend, edge::OptiEdge)
     return nothing
 end
 
-#
 # MOI Methods
-#
 
-### graph attributes
+## graph attributes
+
+function MOI.supports(backend::GraphMOIBackend, attr::MOI.AnyAttribute, args...)
+    return MOI.supports(JuMP.backend(backend), attr, args...)
+end
 
 function MOI.get(
     backend::GraphMOIBackend, attr::AT
diff --git a/src/node_variables.jl b/src/node_variables.jl
index 3088522..393d8b0 100644
--- a/src/node_variables.jl
+++ b/src/node_variables.jl
@@ -243,7 +243,10 @@ function JuMP.ParameterRef(nvref::NodeVariableRef)
     if !JuMP.is_parameter(nvref)
         error("Variable $x is not a parameter.")
     end
-    return ConstraintRef(JuMP.owner_model(nvref), _parameter_index(nvref), ScalarShape())
+    backend = JuMP.backend(nvref.node)
+    ci = _parameter_index(nvref)
+    cref = JuMP.constraint_ref_with_index(backend, ci)
+    return cref
 end
 
 function JuMP.is_parameter(nvref::NodeVariableRef)
@@ -270,7 +273,7 @@ end
 
 function _parameter_index(nvref::NodeVariableRef)
     F, S = MOI.VariableIndex, MOI.Parameter{JuMP.value_type(typeof(nvref))}
-    return MOI.ConstraintIndex{F,S}(JuMP.index(nvref).value)
+    return MOI.ConstraintIndex{F,S}(graph_index(nvref).value)
 end
 
 # variable start values
diff --git a/test/test_optigraph.jl b/test/test_optigraph.jl
index 2d499fd..97b5de6 100644
--- a/test/test_optigraph.jl
+++ b/test/test_optigraph.jl
@@ -333,6 +333,10 @@ function test_variable_constraints()
     # parameter
     @variable(n1, p in Parameter(1.0))
 
+    @test parameter_value(p) == 1.0
+    set_parameter_value(p, 2.0)
+    @test parameter_value(p) == 2.0
+
     # start value
     set_start_value(n2[:x], 3.0)
     @test start_value(n2[:x]) == 3.0
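
With these fixes in place (`ParameterRef` resolved through the graph backend and the parameter's constraint index taken from `graph_index` rather than the node-local index), the parameter accessors work end to end. A small sketch under the same illustrative names as above; actually solving a graph that uses parameters would additionally require an optimizer that supports `MOI.Parameter`.

    using Plasmo, JuMP

    graph = OptiGraph()
    n1 = add_node(graph)
    @variable(n1, x >= 0)
    @variable(n1, p in Parameter(1.0))

    is_parameter(p)         # true
    pref = ParameterRef(p)  # ConstraintRef of the underlying MOI.Parameter constraint
    parameter_value(p)      # 1.0

    # update the parameter in place; the backing MOI.Parameter set is replaced
    set_parameter_value(p, 2.0)
    parameter_value(p)      # 2.0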

From a002af5a949223f824eb6cd3934c8d90f3d813ad Mon Sep 17 00:00:00 2001
From: jalving <jhjalving@gmail.com>
Date: Fri, 1 Nov 2024 17:50:56 -0700
Subject: [PATCH 3/6] more utility funcs

---
 src/Plasmo.jl    |   6 +-
 src/optigraph.jl |  87 ++++++++++++++++--
 src/optinode.jl  |  10 ++-
 src/utilities.jl |  39 --------
 src/utils.jl     | 225 +++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 318 insertions(+), 49 deletions(-)
 delete mode 100644 src/utilities.jl
 create mode 100644 src/utils.jl

diff --git a/src/Plasmo.jl b/src/Plasmo.jl
index 1e192e2..ee624f3 100644
--- a/src/Plasmo.jl
+++ b/src/Plasmo.jl
@@ -110,7 +110,9 @@ export OptiGraph,
     # other functions
 
     set_jump_model,
-    extract_variables
+    extract_variables,
+    is_separable,
+    extract_separable_terms
 
 include("core_types.jl")
 
@@ -140,7 +142,7 @@ include("graph_functions/topology.jl")
 
 include("graph_functions/partition.jl")
 
-include("utilities.jl")
+include("utils.jl")
 
 # extensions
 function __init__()
diff --git a/src/optigraph.jl b/src/optigraph.jl
index 4ee979b..8a05171 100644
--- a/src/optigraph.jl
+++ b/src/optigraph.jl
@@ -1091,27 +1091,100 @@ function has_node_objective(graph::OptiGraph)
     return false
 end
 
+"""
+    node_objective_type(graph::OptiGraph)
+
+Return the most complex objective type among nodes in the given `graph`. The order of
+complexity is: Nonlinear, Quadratic, Linear.
+"""
+function node_objective_type(graph::OptiGraph)
+    if !(has_node_objective(graph))
+        return nothing
+    end
+    
+    obj_types = JuMP.objective_function_type.(all_nodes(graph))
+    if JuMP.GenericNonlinearExpr{NodeVariableRef} in obj_types
+        return JuMP.GenericNonlinearExpr{NodeVariableRef}
+    elseif JuMP.GenericQuadExpr{Float64,NodeVariableRef} in obj_types
+        return JuMP.GenericQuadExpr{Float64,NodeVariableRef}
+    elseif JuMP.GenericAffExpr{Float64,NodeVariableRef} in obj_types
+        return JuMP.GenericAffExpr{Float64,NodeVariableRef}
+    elseif NodeVariableRef in obj_types
+        return JuMP.GenericAffExpr{Float64,NodeVariableRef}
+    else
+        error("Could not determine node objective type")
+    end
+end
+
 """
     set_to_node_objectives(graph::OptiGraph)
 
 Set the `graph` objective to the summation of all of its optinode objectives. Assumes the 
-objective sense is an MOI.MIN_SENSE and adjusts the signs of node objective functions 
-accordingly.
+objective sense is MOI.MIN_SENSE and accounts for the sense of node objectives
+accordingly.
+
+Note that building nonlinear objective functions is much slower than
+linear or quadratic because nonlinear expressions cannot be updated in place.
 """
 function set_to_node_objectives(graph::OptiGraph)
-    obj = 0
+    if has_node_objective(graph)
+        node_obj_type =  node_objective_type(graph)
+        _set_to_node_objectives(graph, node_obj_type)
+    end
+    return nothing
+end
+
+function _set_to_node_objectives(
+    graph::OptiGraph, 
+    obj_type::Type{T} where T <: Union{
+        JuMP.GenericAffExpr{Float64, NodeVariableRef},
+        JuMP.GenericQuadExpr{Float64, NodeVariableRef}
+    }
+)
+    objective = zero(obj_type)
     for node in all_nodes(graph)
         if has_objective(node)
             sense = JuMP.objective_sense(node) == MOI.MAX_SENSE ? -1 : 1
-            obj += sense * JuMP.objective_function(node)
+            JuMP.add_to_expression!(objective, JuMP.objective_function(node), sense)
         end
     end
-    if obj != 0
-        @objective(graph, Min, obj)
+    @objective(graph, Min, objective)
+    return
+end
+
+function _set_to_node_objectives(
+    graph::OptiGraph, 
+    obj_type::Type{T} where T <: JuMP.GenericNonlinearExpr{NodeVariableRef}
+)
+    objective = zero(obj_type)
+    for node in all_nodes(graph)
+        if has_objective(node)
+            sense = JuMP.objective_sense(node) == MOI.MAX_SENSE ? -1 : 1
+            objective += *(sense, objective_function(node))
+        end
     end
-    return nothing
+    @objective(graph, Min, objective)
+    return
 end
 
+# TODO
+"""
+    set_node_objectives_from_graph(graph::OptiGraph)
+
+Set the objective of each node within `graph` by parsing and separating the graph objective
+function. Note this only works if the objective function is separable over the nodes in 
+`graph`.
+"""
+# function set_node_objectives_from_graph(graph::OptiGraph)
+#     obj = objective_function(graph)
+#     if !(is_separable(obj))
+#         error("Cannot set node objectives from graph. It is not separable across nodes.")
+#     end
+#     sense = objective_sense(graph)
+#     _set_node_objectives_from_graph(obj, sense)
+#     return nothing
+# end
+
 """
     JuMP.objective_function(graph::OptiGraph)
 
diff --git a/src/optinode.jl b/src/optinode.jl
index 11a0e11..b413799 100644
--- a/src/optinode.jl
+++ b/src/optinode.jl
@@ -241,7 +241,15 @@ function JuMP.set_objective_sense(node::OptiNode, sense::MOI.OptimizationSense)
 end
 
 function JuMP.objective_function(node::OptiNode)
-    return JuMP.object_dictionary(node)[(node, :objective_function)]
+    if haskey(JuMP.object_dictionary(node), (node,:objective_function))
+        return JuMP.object_dictionary(node)[(node, :objective_function)]
+    else
+        return nothing
+    end
+end
+
+function JuMP.objective_function_type(node::OptiNode)
+    return typeof(objective_function(node))
 end
 
 function JuMP.objective_sense(node::OptiNode)
diff --git a/src/utilities.jl b/src/utilities.jl
deleted file mode 100644
index 5b0cb8a..0000000
--- a/src/utilities.jl
+++ /dev/null
@@ -1,39 +0,0 @@
-### Utilities for querying variables used in constraints
-
-function extract_variables(func)
-    return _extract_variables(func)
-end
-
-function _extract_variables(func::NodeVariableRef)
-    return [func]
-end
-
-function _extract_variables(ref::EdgeConstraintRef)
-    func = JuMP.jump_function(JuMP.constraint_object(ref))
-    return _extract_variables(func)
-end
-
-function _extract_variables(func::JuMP.GenericAffExpr)
-    return collect(keys(func.terms))
-end
-
-function _extract_variables(func::JuMP.GenericQuadExpr)
-    quad_vars = vcat([[term[2]; term[3]] for term in JuMP.quad_terms(func)]...)
-    aff_vars = _extract_variables(func.aff)
-    return union(quad_vars, aff_vars)
-end
-
-function _extract_variables(func::JuMP.GenericNonlinearExpr)
-    vars = NodeVariableRef[]
-    for i in 1:length(func.args)
-        func_arg = func.args[i]
-        if func_arg isa Number
-            continue
-        elseif typeof(func_arg) == NodeVariableRef
-            push!(vars, func_arg)
-        else
-            append!(vars, _extract_variables(func_arg))
-        end
-    end
-    return vars
-end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..0f961ea
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,225 @@
+"""
+    extract_variables(func)
+
+Return the variables contained within the given expression or reference.
+"""
+function extract_variables(func)
+    return _extract_variables(func)
+end
+
+function _extract_variables(func::NodeVariableRef)
+    return [func]
+end
+
+function _extract_variables(ref::EdgeConstraintRef)
+    func = JuMP.jump_function(JuMP.constraint_object(ref))
+    return _extract_variables(func)
+end
+
+function _extract_variables(func::JuMP.GenericAffExpr)
+    return collect(keys(func.terms))
+end
+
+function _extract_variables(func::JuMP.GenericQuadExpr)
+    quad_vars = vcat([[term[2]; term[3]] for term in JuMP.quad_terms(func)]...)
+    aff_vars = _extract_variables(func.aff)
+    return union(quad_vars, aff_vars)
+end
+
+function _extract_variables(func::JuMP.GenericNonlinearExpr)
+    vars = NodeVariableRef[]
+    for i in 1:length(func.args)
+        func_arg = func.args[i]
+        if func_arg isa Number
+            continue
+        elseif typeof(func_arg) == NodeVariableRef
+            push!(vars, func_arg)
+        else
+            append!(vars, _extract_variables(func_arg))
+        end
+    end
+    return vars
+end
+
+function _first_variable(func::JuMP.GenericNonlinearExpr)
+    for i in 1:length(func.args)
+        func_arg = func.args[i]
+        if func_arg isa Number
+            continue
+        elseif typeof(func_arg) == NodeVariableRef
+            return func_arg
+        else
+            return _first_variable(func_arg)
+        end
+    end 
+end
+
+"""
+    is_separable(func)
+
+Return whether the given function is separable across optinodes.
+"""
+function is_separable(func::Union{Number,JuMP.AbstractJuMPScalar})
+    return _is_separable(func)
+end
+
+function _is_separable(::Number)
+    return true
+end
+
+function _is_separable(::NodeVariableRef)
+    return true
+end
+
+function _is_separable(::JuMP.GenericAffExpr{<:Number,NodeVariableRef})
+    return true
+end
+
+function _is_separable(func::JuMP.GenericQuadExpr{<:Number,NodeVariableRef})
+    # check each quadratic term; both of its variables must be on the same node
+    for term in Plasmo.quad_terms(func)
+        # term = (coefficient, variable_1, variable_2)
+        node1 = get_node(term[2])
+        node2 = get_node(term[3])
+
+        # if any term is split across nodes, the objective is not separable
+        if node1 != node2
+            return false
+        end
+    end
+    return true
+end
+
+function _is_separable(func::JuMP.GenericNonlinearExpr{NodeVariableRef})
+    # check for a constant multiplier
+    if func.head == :*
+        if !(func.args[1] isa Number)
+            return false
+        end
+    end
+
+    # if not additive, check if term is separable
+    if func.head != :+ && func.head != :-
+        vars = extract_variables(func)
+        nodes = get_node.(vars)
+        if length(unique(nodes)) > 1
+            return false
+        end
+    end
+
+    # check each argument
+    for arg in func.args
+        if !(is_separable(arg))
+            return false
+        end
+    end
+    return true
+end
+
+"""
+    extract_separable_terms(func::JuMP.AbstractJuMPScalar, graph::OptiGraph)
+
+Extract the separable terms of `func`, grouped by the optinodes of `graph` they belong to.
+NOTE: Nonlinear objectives are not completely tested and may return incorrect results.
+"""
+function extract_separable_terms(func::JuMP.AbstractJuMPScalar, graph::OptiGraph)
+    !is_separable(func) && error("Cannont extract terms. Function is not separable.")
+    return _extract_separable_terms(func, graph)
+end
+
+function _extract_separable_terms(
+    func::Union{Number,Plasmo.NodeVariableRef},
+    graph::OptiGraph
+)
+    return func
+end
+
+function _extract_separable_terms(
+    func::JuMP.GenericAffExpr{<:Number,NodeVariableRef},
+    graph::OptiGraph
+)
+    node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericAffExpr{<:Number,NodeVariableRef}}}()
+    nodes = Plasmo.collect_nodes(func)
+    nodes = intersect(nodes, all_nodes(graph))
+    for node in nodes
+        node_terms[node] = Vector{JuMP.GenericAffExpr{<:Number,NodeVariableRef}}()
+    end
+
+    for term in Plasmo.linear_terms(func)
+        node = get_node(term[2])
+        push!(node_terms[node], term[1]*term[2])
+    end
+
+    return node_terms
+end
+
+function _extract_separable_terms(
+    func::JuMP.GenericQuadExpr{<:Number,NodeVariableRef},
+    graph::OptiGraph
+)
+    node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericQuadExpr{<:Number,NodeVariableRef}}}()
+    nodes = collect_nodes(func)
+    nodes = intersect(nodes, all_nodes(graph))
+    for node in nodes
+        node_terms[node] = Vector{JuMP.GenericQuadExpr{<:Number,NodeVariableRef}}()
+    end
+
+    for term in JuMP.quad_terms(func)
+        node = get_node(term[2])
+        push!(node_terms[node], term[1]*term[2]*term[3])
+    end
+
+    for term in JuMP.linear_terms(func)
+        node = get_node(term[2])
+        push!(node_terms[node], term[1]*term[2])
+    end
+
+    return node_terms
+end
+
+# NOTE: method needs improvement. does not cover all separable cases.
+function _extract_separable_terms(
+    func::JuMP.GenericNonlinearExpr{NodeVariableRef},
+    graph::OptiGraph
+)
+    node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}}()
+    nodes = collect_nodes(func)
+    nodes = intersect(nodes, all_nodes(graph))
+    for node in nodes
+        node_terms[node] = Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}()
+    end
+
+    _extract_separable_terms(func, node_terms)
+
+    return node_terms
+end
+
+function _extract_separable_terms(
+    func::JuMP.GenericNonlinearExpr{NodeVariableRef}, 
+    node_terms::OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}}
+)
+    # check for a constant multiplier
+    multiplier = 1.0
+    if func.head == :*
+        if func.args[1] isa Number
+            multiplier = func.args[1]
+        end
+    end
+
+    # if not additive, get node for this term
+    if func.head != :+ && func.head != :-
+        var = _first_variable(func)
+        node = get_node(var)
+        push!(node_terms[node], multiplier*func)
+    else
+        # check each argument
+        for arg in func.args
+            if arg isa Number
+                continue
+            end
+            _extract_separable_terms(arg, node_terms)
+        end
+    end
+
+    return nothing
+end
\ No newline at end of file
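
A minimal sketch of how the new objective helpers and utilities fit together (the model below is illustrative; `node_objective_type` is not exported, so it is called through the module): `set_to_node_objectives` now inspects the most complex node objective type before assembling the graph objective, and the utilities in `src/utils.jl` expose variable extraction and separability checks.

    using Plasmo, JuMP

    graph = OptiGraph()
    n1 = add_node(graph)
    n2 = add_node(graph)
    @variable(n1, x >= 0)
    @variable(n2, x >= 0)
    link = @linkconstraint(graph, n1[:x] + n2[:x] >= 3)

    # node objectives of different complexity (and different senses)
    @objective(n1, Min, n1[:x])
    @objective(n2, Max, n2[:x]^2)

    Plasmo.node_objective_type(graph)  # GenericQuadExpr{Float64,NodeVariableRef}
    set_to_node_objectives(graph)      # graph objective: Min n1[:x] - n2[:x]^2

    # variables appearing in the link constraint
    extract_variables(link)            # [n1[:x], n2[:x]]

    # no objective term mixes variables from different nodes, so it is separable
    is_separable(objective_function(graph))  # true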

From 21614fce881b1f87fd30fbcdf4de53f57c58df2e Mon Sep 17 00:00:00 2001
From: jalving <jhjalving@gmail.com>
Date: Fri, 1 Nov 2024 21:12:08 -0700
Subject: [PATCH 4/6] linting

---
 src/optigraph.jl       | 120 ++++++++++++++++++++---------------------
 src/optinode.jl        |   2 +-
 src/utils.jl           |  36 ++++++-------
 test/test_optigraph.jl |   2 +-
 4 files changed, 79 insertions(+), 81 deletions(-)

diff --git a/src/optigraph.jl b/src/optigraph.jl
index 8a05171..5867d99 100644
--- a/src/optigraph.jl
+++ b/src/optigraph.jl
@@ -90,9 +90,7 @@ Base.broadcastable(graph::OptiGraph) = Ref(graph)
 # TODO: parameterize on numerical precision like JuMP Models do
 JuMP.value_type(::Type{OptiGraph}) = Float64
 
-#
-# Optigraph methods
-#
+# optigraph methods
 
 """
     graph_backend(graph::OptiGraph)
@@ -105,7 +103,7 @@ function graph_backend(graph::OptiGraph)
     return graph.backend
 end
 
-### Graph Index
+# graph index
 
 """
     graph_index(ref::RT) where {RT<:Union{NodeVariableRef,ConstraintRef}}
@@ -121,7 +119,7 @@ function graph_index(
     return graph_index(graph_backend(graph), ref)
 end
 
-### Assemble OptiGraph
+# assemble optigraph
 
 function _assemble_optigraph(nodes::Vector{<:OptiNode}, edges::Vector{<:OptiEdge})
     graph = OptiGraph()
@@ -171,7 +169,7 @@ function is_valid_optigraph(nodes::Vector{<:OptiNode}, edges::Vector{<:OptiEdge}
     return isempty(setdiff(edge_nodes, nodes)) ? true : false
 end
 
-### Manage OptiNodes
+# manage optinodes
 
 """
     add_node(
@@ -284,7 +282,7 @@ function num_nodes(graph::OptiGraph)
     return n_nodes
 end
 
-### Manage OptiEdges
+# manage optiedges
 
 """
     add_edge(
@@ -460,7 +458,7 @@ function all_elements(graph::OptiGraph)
     return [all_nodes(graph); all_edges(graph)]
 end
 
-### Manage subgraphs
+# manage subgraphs
 
 """
     add_subgraph(graph::OptiGraph; name::Symbol=Symbol(:sg,gensym()))
@@ -549,7 +547,7 @@ function num_subgraphs(graph::OptiGraph)
     return n_subs
 end
 
-### Link Constraints
+# link constraints
 
 """
     num_local_link_constraints(
@@ -664,7 +662,7 @@ function all_link_constraints(graph::OptiGraph)
     return vcat(all_constraints.(all_edges(graph))...)
 end
 
-### Local Constraints
+# local constraints
 
 """
     num_local_constraints(
@@ -732,13 +730,7 @@ function local_constraints(graph::OptiGraph)
     return vcat(all_constraints.(local_elements(graph))...)
 end
 
-# TODO Methods
-# num_linked_variables(graph)
-# linked_variables(graph)
-
-#
-# MOI Methods
-#
+# MOI methods
 
 function MOI.get(
     graph::OptiGraph, attr::AT
@@ -752,9 +744,7 @@ function MOI.set(
     return MOI.set(graph_backend(graph), attr, args...)
 end
 
-#
-# JuMP Methods
-#
+# JuMP methods
 
 """
     JuMP.name(graph::OptiGraph)
@@ -775,7 +765,7 @@ function JuMP.set_name(graph::OptiGraph, name::Symbol)
     return nothing
 end
 
-### Variables
+# variable methods
 
 """
     JuMP.all_variables(graph::OptiGraph)
@@ -896,7 +886,7 @@ function JuMP.dual(graph::OptiGraph, cref::EdgeConstraintRef; result::Int=1)
     return MOI.get(graph_backend(graph), MOI.ConstraintDual(result), cref)
 end
 
-### Constraints
+# constraint methods
 
 """
     JuMP.add_constraint(graph::OptiGraph, con::JuMP.AbstractConstraint, name::String="")
@@ -1003,7 +993,7 @@ function JuMP.num_constraints(graph::OptiGraph; count_variable_in_set_constraint
     return num_cons
 end
 
-### Other Methods
+# other methods
 
 """
     JuMP.backend(graph::OptiGraph)
@@ -1039,7 +1029,7 @@ function JuMP.relax_integrality(graph::OptiGraph)
     return unrelax
 end
 
-### Nonlinear Operators
+# nonlinear operators
 
 """
     JuMP.add_nonlinear_operator(
@@ -1075,7 +1065,7 @@ function JuMP.add_nonlinear_operator(
     return JuMP.NonlinearOperator(f, registered_name)
 end
 
-### Objective function
+# objective function
 
 """
     has_node_objective(graph::OptiGraph)
@@ -1101,7 +1091,7 @@ function node_objective_type(graph::OptiGraph)
     if !(has_node_objective(graph))
         return nothing
     end
-    
+
     obj_types = JuMP.objective_function_type.(all_nodes(graph))
     if JuMP.GenericNonlinearExpr{NodeVariableRef} in obj_types
         return JuMP.GenericNonlinearExpr{NodeVariableRef}
@@ -1128,18 +1118,22 @@ linear or quadratic because nonlinear expressions cannot be updated in place.
 """
 function set_to_node_objectives(graph::OptiGraph)
     if has_node_objective(graph)
-        node_obj_type =  node_objective_type(graph)
+        node_obj_type = node_objective_type(graph)
         _set_to_node_objectives(graph, node_obj_type)
     end
     return nothing
 end
 
 function _set_to_node_objectives(
-    graph::OptiGraph, 
-    obj_type::Type{T} where T <: Union{
-        JuMP.GenericAffExpr{Float64, NodeVariableRef},
-        JuMP.GenericQuadExpr{Float64, NodeVariableRef}
-    }
+    graph::OptiGraph,
+    obj_type::Type{
+        T
+    } where {
+        T<:Union{
+            JuMP.GenericAffExpr{Float64,NodeVariableRef},
+            JuMP.GenericQuadExpr{Float64,NodeVariableRef},
+        },
+    },
 )
     objective = zero(obj_type)
     for node in all_nodes(graph)
@@ -1149,12 +1143,12 @@ function _set_to_node_objectives(
         end
     end
     @objective(graph, Min, objective)
-    return
+    return nothing
 end
 
 function _set_to_node_objectives(
-    graph::OptiGraph, 
-    obj_type::Type{T} where T <: JuMP.GenericNonlinearExpr{NodeVariableRef}
+    graph::OptiGraph,
+    obj_type::Type{T} where {T<:JuMP.GenericNonlinearExpr{NodeVariableRef}},
 )
     objective = zero(obj_type)
     for node in all_nodes(graph)
@@ -1164,27 +1158,9 @@ function _set_to_node_objectives(
         end
     end
     @objective(graph, Min, objective)
-    return
+    return nothing
 end
 
-# TODO
-"""
-    set_node_objectives_from_graph(graph::OptiGraph)
-
-Set the objective of each node within `graph` by parsing and separating the graph objective
-function. Note this only works if the objective function is separable over the nodes in 
-`graph`.
-"""
-# function set_node_objectives_from_graph(graph::OptiGraph)
-#     obj = objective_function(graph)
-#     if !(is_separable(obj))
-#         error("Cannot set node objectives from graph. It is not separable across nodes.")
-#     end
-#     sense = objective_sense(graph)
-#     _set_node_objectives_from_graph(obj, sense)
-#     return nothing
-# end
-
 """
     JuMP.objective_function(graph::OptiGraph)
 
@@ -1308,7 +1284,7 @@ function _moi_set_objective_function(graph::OptiGraph, expr::JuMP.AbstractJuMPSc
     return nothing
 end
 
-### objective coefficient - linear
+# objective coefficient - linear
 
 """
     JuMP.set_objective_coefficient(
@@ -1355,7 +1331,7 @@ function _set_objective_coefficient(
     return nothing
 end
 
-### objective coefficient - linear - vector
+# objective coefficient - linear - vector
 
 function JuMP.set_objective_coefficient(
     graph::OptiGraph,
@@ -1403,7 +1379,7 @@ function _set_objective_coefficient(
     return nothing
 end
 
-### objective coefficient - quadratic
+# objective coefficient - quadratic
 
 function JuMP.set_objective_coefficient(
     graph::OptiGraph, variable_1::NodeVariableRef, variable_2::NodeVariableRef, coeff::Real
@@ -1429,7 +1405,7 @@ function _set_objective_coefficient(
     return nothing
 end
 
-# if existing objective is quadratic
+## if existing objective is quadratic
 function _set_objective_coefficient(
     graph::OptiGraph,
     variable_1::NodeVariableRef,
@@ -1450,7 +1426,7 @@ function _set_objective_coefficient(
     return nothing
 end
 
-### objective coefficient - quadratic - vector
+# objective coefficient - quadratic - vector
 
 function JuMP.set_objective_coefficient(
     graph::OptiGraph,
@@ -1470,7 +1446,7 @@ function JuMP.set_objective_coefficient(
     return nothing
 end
 
-# if existing objective is not quadratic
+## if existing objective is not quadratic
 function _set_objective_coefficient(
     graph::OptiGraph,
     variables_1::AbstractVector{<:NodeVariableRef},
@@ -1487,7 +1463,7 @@ function _set_objective_coefficient(
     return nothing
 end
 
-# if existing objective is quadratic
+## if existing objective is quadratic
 function _set_objective_coefficient(
     graph::OptiGraph,
     variables_1::AbstractVector{<:NodeVariableRef},
@@ -1513,3 +1489,25 @@ end
 function JuMP.unregister(graph::OptiGraph, key::Symbol)
     return delete!(object_dictionary(graph), key)
 end
+
+# TODO Methods
+# num_linked_variables(graph)
+# linked_variables(graph)
+
+# TODO
+"""
+    set_node_objectives_from_graph(graph::OptiGraph)
+
+Set the objective of each node within `graph` by parsing and separating the graph objective
+function. Note this only works if the objective function is separable over the nodes in 
+`graph`.
+"""
+# function set_node_objectives_from_graph(graph::OptiGraph)
+#     obj = objective_function(graph)
+#     if !(is_separable(obj))
+#         error("Cannot set node objectives from graph. It is not separable across nodes.")
+#     end
+#     sense = objective_sense(graph)
+#     _set_node_objectives_from_graph(obj, sense)
+#     return nothing
+# end
diff --git a/src/optinode.jl b/src/optinode.jl
index b413799..6e783f7 100644
--- a/src/optinode.jl
+++ b/src/optinode.jl
@@ -241,7 +241,7 @@ function JuMP.set_objective_sense(node::OptiNode, sense::MOI.OptimizationSense)
 end
 
 function JuMP.objective_function(node::OptiNode)
-    if haskey(JuMP.object_dictionary(node), (node,:objective_function))
+    if haskey(JuMP.object_dictionary(node), (node, :objective_function))
         return JuMP.object_dictionary(node)[(node, :objective_function)]
     else
         return nothing
diff --git a/src/utils.jl b/src/utils.jl
index 0f961ea..22dd15d 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -51,7 +51,7 @@ function _first_variable(func::JuMP.GenericNonlinearExpr)
         else
             return _first_variable(func_arg)
         end
-    end 
+    end
 end
 
 """
@@ -128,17 +128,17 @@ function extract_separable_terms(func::JuMP.AbstractJuMPScalar, graph::OptiGraph
 end
 
 function _extract_separable_terms(
-    func::Union{Number,Plasmo.NodeVariableRef},
-    graph::OptiGraph
+    func::Union{Number,Plasmo.NodeVariableRef}, graph::OptiGraph
 )
     return func
 end
 
 function _extract_separable_terms(
-    func::JuMP.GenericAffExpr{<:Number,NodeVariableRef},
-    graph::OptiGraph
+    func::JuMP.GenericAffExpr{<:Number,NodeVariableRef}, graph::OptiGraph
 )
-    node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericAffExpr{<:Number,NodeVariableRef}}}()
+    node_terms = OrderedDict{
+        OptiNode,Vector{JuMP.GenericAffExpr{<:Number,NodeVariableRef}}
+    }()
     nodes = Plasmo.collect_nodes(func)
     nodes = intersect(nodes, all_nodes(graph))
     for node in nodes
@@ -147,17 +147,18 @@ function _extract_separable_terms(
 
     for term in Plasmo.linear_terms(func)
         node = get_node(term[2])
-        push!(node_terms[node], term[1]*term[2])
+        push!(node_terms[node], term[1] * term[2])
     end
 
     return node_terms
 end
 
 function _extract_separable_terms(
-    func::JuMP.GenericQuadExpr{<:Number,NodeVariableRef},
-    graph::OptiGraph
+    func::JuMP.GenericQuadExpr{<:Number,NodeVariableRef}, graph::OptiGraph
 )
-    node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericQuadExpr{<:Number,NodeVariableRef}}}()
+    node_terms = OrderedDict{
+        OptiNode,Vector{JuMP.GenericQuadExpr{<:Number,NodeVariableRef}}
+    }()
     nodes = collect_nodes(func)
     nodes = intersect(nodes, all_nodes(graph))
     for node in nodes
@@ -166,12 +167,12 @@ function _extract_separable_terms(
 
     for term in JuMP.quad_terms(func)
         node = get_node(term[2])
-        push!(node_terms[node], term[1]*term[2]*term[3])
+        push!(node_terms[node], term[1] * term[2] * term[3])
     end
 
     for term in JuMP.linear_terms(func)
         node = get_node(term[2])
-        push!(node_terms[node], term[1]*term[2])
+        push!(node_terms[node], term[1] * term[2])
     end
 
     return node_terms
@@ -179,8 +180,7 @@ end
 
 # NOTE: method needs improvement. does not cover all separable cases.
 function _extract_separable_terms(
-    func::JuMP.GenericNonlinearExpr{NodeVariableRef},
-    graph::OptiGraph
+    func::JuMP.GenericNonlinearExpr{NodeVariableRef}, graph::OptiGraph
 )
     node_terms = OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}}()
     nodes = collect_nodes(func)
@@ -195,8 +195,8 @@ function _extract_separable_terms(
 end
 
 function _extract_separable_terms(
-    func::JuMP.GenericNonlinearExpr{NodeVariableRef}, 
-    node_terms::OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}}
+    func::JuMP.GenericNonlinearExpr{NodeVariableRef},
+    node_terms::OrderedDict{OptiNode,Vector{JuMP.GenericNonlinearExpr{NodeVariableRef}}},
 )
     # check for a constant multiplier
     multiplier = 1.0
@@ -210,7 +210,7 @@ function _extract_separable_terms(
     if func.head != :+ && func.head != :-
         var = _first_variable(func)
         node = get_node(var)
-        push!(node_terms[node], multiplier*func)
+        push!(node_terms[node], multiplier * func)
     else
         # check each argument
         for arg in func.args
@@ -222,4 +222,4 @@ function _extract_separable_terms(
     end
 
     return nothing
-end
\ No newline at end of file
+end
diff --git a/test/test_optigraph.jl b/test/test_optigraph.jl
index 97b5de6..e9ac7a6 100644
--- a/test/test_optigraph.jl
+++ b/test/test_optigraph.jl
@@ -331,7 +331,7 @@ function test_variable_constraints()
     @variable(n2, 0 <= x <= 2)
 
     # parameter
-    @variable(n1, p in Parameter(1.0))
+    @variable(n2, p in Parameter(1.0))
 
     @test parameter_value(p) == 1.0
     set_parameter_value(p, 2.0)

From faf295566550a355f20a4086262e3452977e97d5 Mon Sep 17 00:00:00 2001
From: jalving <jhjalving@gmail.com>
Date: Fri, 1 Nov 2024 21:29:52 -0700
Subject: [PATCH 5/6] bump version

---
 Project.toml | 2 +-
 README.md    | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index aad8813..35be9d2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -2,7 +2,7 @@ name = "Plasmo"
 uuid = "d3f7391f-f14a-50cc-bbe4-76a32d1bad3c"
 authors = ["Jordan Jalving <jhjalving@gmail.com>"]
 repo = "https://github.com/plasmo-dev/Plasmo.jl.git"
-version = "0.6.3"
+version = "0.6.4"
 
 [deps]
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
diff --git a/README.md b/README.md
index 2400fb7..124be5c 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,8 @@
 [![](https://img.shields.io/badge/docs-dev-blue.svg)](https://plasmo-dev.github.io/Plasmo.jl/dev/)
 [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://plasmo-dev.github.io/Plasmo.jl/stable/)
 [![DOI](https://zenodo.org/badge/96967382.svg)](https://zenodo.org/badge/latestdoi/96967382)
+[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
+
 
 # Plasmo.jl
 

From eefa97c5bffd87ebfef3c105b5bcedf061efc718 Mon Sep 17 00:00:00 2001
From: jalving <jhjalving@gmail.com>
Date: Fri, 1 Nov 2024 21:51:31 -0700
Subject: [PATCH 6/6] add separable tests

---
 test/test_optigraph.jl | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/test/test_optigraph.jl b/test/test_optigraph.jl
index e9ac7a6..9f110c7 100644
--- a/test/test_optigraph.jl
+++ b/test/test_optigraph.jl
@@ -52,6 +52,30 @@ function test_simple_graph()
     @test MOIU.state(graph) == MOIU.ATTACHED_OPTIMIZER
     MOIU.drop_optimizer(graph)
     @test MOIU.state(graph) == MOIU.NO_OPTIMIZER
+
+    # test separable
+    @test is_separable(objective_function(graph))
+    sep_terms = extract_separable_terms(objective_function(graph), graph)
+    @test sep_terms[nodes[1]][1] == 1 * nodes[1][:x]
+    @test sep_terms[nodes[2]][1] == 2 * nodes[2][:x]
+
+    @objective(graph, Min, nodes[1][:x]^2 + nodes[2][:x]^2)
+    @test is_separable(objective_function(graph))
+    sep_terms = extract_separable_terms(objective_function(graph), graph)
+    @test sep_terms[nodes[1]][1] == 1 * nodes[1][:x]^2
+    @test sep_terms[nodes[2]][1] == 1 * nodes[2][:x]^2
+
+    @objective(graph, Min, nodes[1][:x]^3 + nodes[2][:x]^3)
+    @test is_separable(objective_function(graph))
+    sep_terms = extract_separable_terms(objective_function(graph), graph)
+    @test sep_terms[nodes[1]][1] isa JuMP.GenericNonlinearExpr{NodeVariableRef}
+    @test sep_terms[nodes[2]][1] isa JuMP.GenericNonlinearExpr{NodeVariableRef}
+
+    @objective(graph, Min, nodes[1][:x]^2 + nodes[2][:x]^2 + nodes[1][:x] * nodes[2][:x])
+    @test is_separable(objective_function(graph)) == false
+
+    @objective(graph, Min, nodes[1][:x]^3 * nodes[2][:x]^2)
+    @test is_separable(objective_function(graph)) == false
 end
 
 function test_direct_moi_graph()
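
For reference, here is a sketch of how the separable-terms output might be consumed, roughly what the still-commented-out `set_node_objectives_from_graph` is intended to do. The helper name below is hypothetical and not part of the patch series; it relies only on the fact that `extract_separable_terms` returns an `OrderedDict` mapping each optinode to a vector of expressions, which can be summed into per-node objectives.

    # hypothetical helper (not part of this patch series): push a separable
    # graph objective down onto the optinodes it touches
    function set_node_objectives_from_separable_terms(graph::OptiGraph)
        obj = objective_function(graph)
        is_separable(obj) || error("graph objective is not separable across nodes")
        sense = objective_sense(graph)
        node_terms = extract_separable_terms(obj, graph)
        for (node, terms) in node_terms
            # sum this node's terms into one expression and set it as the node
            # objective, keeping the graph's objective sense
            @objective(node, sense, sum(terms))
        end
        return nothing
    end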