From b415851b14e040aabdbb43429da9d8bf7e24a1c0 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 23 Feb 2023 13:29:20 -0500 Subject: [PATCH 1/6] Add new subsitute_subgraph() graph classes This commit adds a new method, substitute_subgraph(), to the PyGraph and PyDiGraph classes. It is used to replace a subgraph in the graph with an external graph. --- ...-substitute_subgraph-d491479ed931cb79.yaml | 6 + src/digraph.rs | 120 +++++++++++++++++- src/graph.rs | 113 ++++++++++++++++- .../digraph/test_substitute_subgraph.py | 56 ++++++++ .../graph/test_substitute_subgraph.py | 58 +++++++++ 5 files changed, 350 insertions(+), 3 deletions(-) create mode 100644 releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml create mode 100644 tests/rustworkx_tests/digraph/test_substitute_subgraph.py create mode 100644 tests/rustworkx_tests/graph/test_substitute_subgraph.py diff --git a/releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml b/releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml new file mode 100644 index 0000000000..78c3fb8a61 --- /dev/null +++ b/releasenotes/notes/add-substitute_subgraph-d491479ed931cb79.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added new methods, :meth:`.PyDiGraph.subsitute_subgraph` and + :meth:`.PyGraph.substitute_subgraph`, which is used to replace + a subgraph in a graph object with an external graph. diff --git a/src/digraph.rs b/src/digraph.rs index 4839b81ace..62981557b0 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -22,7 +22,7 @@ use std::io::{BufReader, BufWriter}; use std::str; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use rustworkx_core::dictmap::*; @@ -2529,6 +2529,124 @@ impl PyDiGraph { } } + /// Substitute a subgraph in the graph with a different subgraph + /// + /// This is used to replace a subgraph in this graph with another graph. A similar result + /// can be achieved by combining :meth:`.contract_nodes` abd + /// :meth:`.substitute_node_with_subgraph` as it + /// + /// :param list nodes: A list of nodes in this graph representing the subgraph + /// to be removed. + /// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with + /// :param dict input_node_map: The mapping of node indices from ```nodes`` to a node + /// in ``subgraph``. This is used for incoming and outgoing edges into the removed + /// subgraph. This will replace any edges conneted to a node in ``nodes`` with the + /// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this + /// mapping. + /// :param callable edge_weight_map: An optional callable object that when + /// used will receive an edge's weight/data payload from ``subgraph`` and + /// will return an object to use as the weight for a newly created edge + /// after the edge is mapped from ``other``. If not specified the weight + /// from the edge in ``other`` will be copied by reference and used. + /// + /// :returns: A mapping of node indices in ``other`` to the new node index in this graph + /// :rtype: NodeMap + pub fn substitute_subgraph( + &mut self, + py: Python, + nodes: Vec, + other: &PyDiGraph, + input_node_map: HashMap, + edge_weight_map: Option, + ) -> PyResult { + let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut node_map: IndexMap = + IndexMap::with_capacity_and_hasher( + other.graph.node_count(), + ahash::RandomState::default(), + ); + let removed_nodes: HashSet = nodes.iter().map(|n| NodeIndex::new(*n)).collect(); + + let weight_map_fn = |obj: &PyObject, weight_fn: &Option| -> PyResult { + match weight_fn { + Some(weight_fn) => weight_fn.call1(py, (obj,)), + None => Ok(obj.clone_ref(py)), + } + }; + for node in nodes { + let index = NodeIndex::new(node); + io_nodes.extend( + self.graph + .edges_directed(index, petgraph::Direction::Incoming) + .filter_map(|edge| { + if !removed_nodes.contains(&edge.source()) { + Some((edge.source(), edge.target(), edge.weight().clone_ref(py))) + } else { + None + } + }), + ); + io_nodes.extend( + self.graph + .edges_directed(index, petgraph::Direction::Outgoing) + .filter_map(|edge| { + if !removed_nodes.contains(&edge.target()) { + Some((edge.source(), edge.target(), edge.weight().clone_ref(py))) + } else { + None + } + }), + ); + self.graph.remove_node(index); + } + for node in other.graph.node_indices() { + let weight = other.graph.node_weight(node).unwrap(); + let new_index = self.graph.add_node(weight.clone_ref(py)); + node_map.insert(node.index(), new_index.index()); + } + for edge in other.graph.edge_references() { + let new_source = node_map[edge.source().index()]; + let new_target = node_map[edge.target().index()]; + self.graph.add_edge( + NodeIndex::new(new_source), + NodeIndex::new(new_target), + weight_map_fn(edge.weight(), &edge_weight_map)?, + ); + } + for edge in io_nodes { + let old_source = edge.0; + let new_source = if removed_nodes.contains(&old_source) { + match input_node_map.get(&old_source.index()) { + Some(new_source) => NodeIndex::new(node_map[new_source]), + None => { + let missing_index = old_source.index(); + return Err(PyIndexError::new_err(format!( + "Input/Output node {missing_index} not found in io_node_map" + ))); + } + } + } else { + old_source + }; + let old_target = edge.1; + let new_target = if removed_nodes.contains(&old_target) { + match input_node_map.get(&old_target.index()) { + Some(new_target) => NodeIndex::new(node_map[new_target]), + None => { + let missing_index = old_target.index(); + return Err(PyIndexError::new_err(format!( + "Input/Output node {missing_index} not found in io_node_map" + ))); + } + } + } else { + old_target + }; + self.graph.add_edge(new_source, new_target, edge.2); + } + Ok(NodeMap { node_map }) + } + /// Return a new PyDiGraph object for an edge induced subgraph of this graph /// /// The induced subgraph contains each edge in `edge_list` and each node diff --git a/src/graph.rs b/src/graph.rs index 022fd80506..aa326ab853 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -20,7 +20,7 @@ use std::io::{BufReader, BufWriter}; use std::str; use hashbrown::{HashMap, HashSet}; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; @@ -36,7 +36,9 @@ use numpy::Complex64; use numpy::PyReadonlyArray2; use super::dot_utils::build_dot; -use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList}; +use super::iterators::{ + EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList, +}; use super::{ find_node_by_weight, merge_duplicates, weight_callable, IsNan, NoEdgeBetweenNodes, NodesRemoved, StablePyGraph, @@ -1715,6 +1717,113 @@ impl PyGraph { out_graph } + /// Substitute a subgraph in the graph with a different subgraph + /// + /// This is used to replace a subgraph in this graph with another graph. A similar result + /// can be achieved by combining :meth:`.contract_nodes` abd + /// :meth:`.substitute_node_with_subgraph` as it + /// + /// :param list nodes: A list of nodes in this graph representing the subgraph + /// to be removed. + /// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with + /// :param dict input_node_map: The mapping of node indices from ```nodes`` to a node + /// in ``subgraph``. This is used for incoming and outgoing edges into the removed + /// subgraph. This will replace any edges conneted to a node in ``nodes`` with the + /// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this + /// mapping. + /// :param callable edge_weight_map: An optional callable object that when + /// used will receive an edge's weight/data payload from ``subgraph`` and + /// will return an object to use as the weight for a newly created edge + /// after the edge is mapped from ``other``. If not specified the weight + /// from the edge in ``other`` will be copied by reference and used. + /// + /// :returns: A mapping of node indices in ``other`` to the new node index in this graph + /// :rtype: NodeMap + pub fn substitute_subgraph( + &mut self, + py: Python, + nodes: Vec, + other: &PyGraph, + input_node_map: HashMap, + edge_weight_map: Option, + ) -> PyResult { + let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut node_map: IndexMap = + IndexMap::with_capacity_and_hasher( + other.graph.node_count(), + ahash::RandomState::default(), + ); + let removed_nodes: HashSet = nodes.iter().map(|n| NodeIndex::new(*n)).collect(); + + let weight_map_fn = |obj: &PyObject, weight_fn: &Option| -> PyResult { + match weight_fn { + Some(weight_fn) => weight_fn.call1(py, (obj,)), + None => Ok(obj.clone_ref(py)), + } + }; + for node in nodes { + let index = NodeIndex::new(node); + io_nodes.extend( + self.graph + .edges_directed(index, petgraph::Direction::Outgoing) + .filter_map(|edge| { + if !removed_nodes.contains(&edge.target()) { + Some((edge.source(), edge.target(), edge.weight().clone_ref(py))) + } else { + None + } + }), + ); + self.graph.remove_node(index); + } + for node in other.graph.node_indices() { + let weight = other.graph.node_weight(node).unwrap(); + let new_index = self.graph.add_node(weight.clone_ref(py)); + node_map.insert(node.index(), new_index.index()); + } + for edge in other.graph.edge_references() { + let new_source = node_map[edge.source().index()]; + let new_target = node_map[edge.target().index()]; + self.graph.add_edge( + NodeIndex::new(new_source), + NodeIndex::new(new_target), + weight_map_fn(edge.weight(), &edge_weight_map)?, + ); + } + for edge in io_nodes { + let old_source = edge.0; + let new_source = if removed_nodes.contains(&old_source) { + match input_node_map.get(&old_source.index()) { + Some(new_source) => NodeIndex::new(node_map[new_source]), + None => { + let missing_index = old_source.index(); + return Err(PyIndexError::new_err(format!( + "Input/Output node {missing_index} not found in io_node_map" + ))); + } + } + } else { + old_source + }; + let old_target = edge.1; + let new_target = if removed_nodes.contains(&old_target) { + match input_node_map.get(&old_target.index()) { + Some(new_target) => NodeIndex::new(node_map[new_target]), + None => { + let missing_index = old_target.index(); + return Err(PyIndexError::new_err(format!( + "Input/Output node {missing_index} not found in io_node_map" + ))); + } + } + } else { + old_target + }; + self.graph.add_edge(new_source, new_target, edge.2); + } + Ok(NodeMap { node_map }) + } + /// Return a shallow copy of the graph /// /// All node and edge weight/data payloads in the copy will have a diff --git a/tests/rustworkx_tests/digraph/test_substitute_subgraph.py b/tests/rustworkx_tests/digraph/test_substitute_subgraph.py new file mode 100644 index 0000000000..a51be37de6 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_substitute_subgraph.py @@ -0,0 +1,56 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import rustworkx + + +class TestSubstitute(unittest.TestCase): + def setUp(self): + super().setUp() + self.graph = rustworkx.generators.directed_path_graph(5) + + def test_empty_replacement(self): + in_graph = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + self.graph.substitute_subgraph([2], in_graph, {}) + + def test_single_node(self): + in_graph = rustworkx.PyDiGraph() + in_graph.add_node(0) + in_graph.add_child(0, 1, "edge") + res = self.graph.substitute_subgraph([2], in_graph, {2: 0}) + self.assertEqual([(0, 1), (2, 5), (1, 2), (3, 4), (2, 3)], self.graph.edge_list()) + self.assertEqual("edge", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_edge_weight_modifier(self): + in_graph = rustworkx.PyDiGraph() + in_graph.add_node(0) + in_graph.add_child(0, 1, "edge") + res = self.graph.substitute_subgraph( + [2], + in_graph, + {2: 0}, + edge_weight_map=lambda edge: edge + "-migrated", + ) + self.assertEqual([(0, 1), (2, 5), (1, 2), (3, 4), (2, 3)], self.graph.edge_list()) + self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_multiple_mapping(self): + graph = rustworkx.generators.directed_star_graph(5) + in_graph = rustworkx.generators.directed_star_graph(3, inward=True) + res = graph.substitute_subgraph([0, 1, 2], in_graph, {0: 0, 1: 1, 2: 2}) + self.assertEqual({0: 2, 1: 1, 2: 0}, res) + expected = [(1, 2), (0, 2), (2, 4), (2, 3)] + self.assertEqual(expected, graph.edge_list()) diff --git a/tests/rustworkx_tests/graph/test_substitute_subgraph.py b/tests/rustworkx_tests/graph/test_substitute_subgraph.py new file mode 100644 index 0000000000..3cb992fca7 --- /dev/null +++ b/tests/rustworkx_tests/graph/test_substitute_subgraph.py @@ -0,0 +1,58 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import rustworkx + + +class TestSubstitute(unittest.TestCase): + def setUp(self): + super().setUp() + self.graph = rustworkx.generators.path_graph(5) + + def test_empty_replacement(self): + in_graph = rustworkx.PyGraph() + with self.assertRaises(IndexError): + self.graph.substitute_subgraph([2], in_graph, {}) + + def test_single_node(self): + in_graph = rustworkx.PyGraph() + in_graph.add_node(0) + in_graph.add_node(1) + in_graph.add_edge(0, 1, "edge") + res = self.graph.substitute_subgraph([2], in_graph, {2: 0}) + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (2, 1)], self.graph.edge_list()) + self.assertEqual("edge", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_edge_weight_modifier(self): + in_graph = rustworkx.PyGraph() + in_graph.add_node(0) + in_graph.add_node(1) + in_graph.add_edge(0, 1, "edge") + res = self.graph.substitute_subgraph( + [2], + in_graph, + {2: 0}, + edge_weight_map=lambda edge: edge + "-migrated", + ) + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (2, 1)], self.graph.edge_list()) + self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5)) + self.assertEqual(res, {0: 2, 1: 5}) + + def test_multiple_mapping(self): + graph = rustworkx.generators.star_graph(5) + in_graph = rustworkx.generators.path_graph(3) + res = graph.substitute_subgraph([0, 1, 2], in_graph, {0: 0, 1: 1, 2: 2}) + self.assertEqual({0: 2, 1: 1, 2: 0}, res) + expected = [(2, 1), (1, 0), (2, 4), (2, 3)] + self.assertEqual(expected, graph.edge_list()) From 10fe40c5c577630b21f534dd6becd7b4e71fcaad Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 24 Feb 2023 09:39:58 -0500 Subject: [PATCH 2/6] Fix MSRV compatibility --- src/digraph.rs | 6 ++++-- src/graph.rs | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index c78262b414..d0592a4e8c 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2618,7 +2618,8 @@ impl PyDiGraph { None => { let missing_index = old_source.index(); return Err(PyIndexError::new_err(format!( - "Input/Output node {missing_index} not found in io_node_map" + "Input/Output node {} not found in io_node_map", + missing_index ))); } } @@ -2632,7 +2633,8 @@ impl PyDiGraph { None => { let missing_index = old_target.index(); return Err(PyIndexError::new_err(format!( - "Input/Output node {missing_index} not found in io_node_map" + "Input/Output node {} not found in io_node_map", + missing_index ))); } } diff --git a/src/graph.rs b/src/graph.rs index 4490a9c3f2..bf74cfc7a9 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1795,7 +1795,8 @@ impl PyGraph { None => { let missing_index = old_source.index(); return Err(PyIndexError::new_err(format!( - "Input/Output node {missing_index} not found in io_node_map" + "Input/Output node {} not found in io_node_map", + missing_index ))); } } @@ -1809,7 +1810,8 @@ impl PyGraph { None => { let missing_index = old_target.index(); return Err(PyIndexError::new_err(format!( - "Input/Output node {missing_index} not found in io_node_map" + "Input/Output node {} not found in io_node_map", + missing_index ))); } } From 9ab2100d5d53f013aebcc2ce133fafe1517c639a Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 17 Mar 2023 13:13:23 -0400 Subject: [PATCH 3/6] Fix docs build --- src/digraph.rs | 4 ++-- src/graph.rs | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index 2405478e3b..3936a24d0e 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2529,8 +2529,8 @@ impl PyDiGraph { /// Substitute a subgraph in the graph with a different subgraph /// /// This is used to replace a subgraph in this graph with another graph. A similar result - /// can be achieved by combining :meth:`.contract_nodes` abd - /// :meth:`.substitute_node_with_subgraph` as it + /// can be achieved by combining :meth:`~.PyDiGraph.contract_nodes` and + /// :meth:`~.PyDiGraph.substitute_node_with_subgraph`. /// /// :param list nodes: A list of nodes in this graph representing the subgraph /// to be removed. diff --git a/src/graph.rs b/src/graph.rs index 97651fadf3..08c34c1f01 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1716,10 +1716,6 @@ impl PyGraph { /// Substitute a subgraph in the graph with a different subgraph /// - /// This is used to replace a subgraph in this graph with another graph. A similar result - /// can be achieved by combining :meth:`.contract_nodes` abd - /// :meth:`.substitute_node_with_subgraph` as it - /// /// :param list nodes: A list of nodes in this graph representing the subgraph /// to be removed. /// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with From a3449e2559f21d8673ef4a16805f607a0f958b5a Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 24 May 2023 17:25:11 -0400 Subject: [PATCH 4/6] Add option for cycle checking --- src/digraph.rs | 60 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index fcccd6c747..a4c7abcd56 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -227,9 +227,10 @@ impl PyDiGraph { p_index: NodeIndex, c_index: NodeIndex, edge: PyObject, + force: bool, ) -> PyResult { // Only check for cycles if instance attribute is set to true - if self.check_cycle { + if self.check_cycle || force { // Only check for a cycle (by running has_path_connecting) if // the new edge could potentially add a cycle let cycle_check_required = is_cycle_check_required(self, p_index, c_index); @@ -270,11 +271,11 @@ impl PyDiGraph { .collect::>(); for (other_index, edge_index, weight) in edges { if direction { - self._add_edge(node_between_index, index, weight.clone_ref(py))?; - self._add_edge(index, other_index, weight.clone_ref(py))?; + self._add_edge(node_between_index, index, weight.clone_ref(py), false)?; + self._add_edge(index, other_index, weight.clone_ref(py), false)?; } else { - self._add_edge(other_index, index, weight.clone_ref(py))?; - self._add_edge(index, node_between_index, weight.clone_ref(py))?; + self._add_edge(other_index, index, weight.clone_ref(py), false)?; + self._add_edge(index, node_between_index, weight.clone_ref(py), false)?; } self.graph.remove_edge(edge_index); } @@ -1056,7 +1057,7 @@ impl PyDiGraph { } } for (source, target, weight) in edge_list { - self._add_edge(source, target, weight)?; + self._add_edge(source, target, weight, false)?; } self.graph.remove_node(index); self.node_removed = true; @@ -1088,7 +1089,7 @@ impl PyDiGraph { "One of the endpoints of the edge does not exist in graph", )); } - let out_index = self._add_edge(p_index, c_index, edge)?; + let out_index = self._add_edge(p_index, c_index, edge, false)?; Ok(out_index) } @@ -1158,7 +1159,12 @@ impl PyDiGraph { while max_index >= self.node_count() { self.graph.add_node(py.None()); } - self._add_edge(NodeIndex::new(source), NodeIndex::new(target), py.None())?; + self._add_edge( + NodeIndex::new(source), + NodeIndex::new(target), + py.None(), + false, + )?; } Ok(()) } @@ -1183,7 +1189,12 @@ impl PyDiGraph { while max_index >= self.node_count() { self.graph.add_node(py.None()); } - self._add_edge(NodeIndex::new(source), NodeIndex::new(target), weight)?; + self._add_edge( + NodeIndex::new(source), + NodeIndex::new(target), + weight, + false, + )?; } Ok(()) } @@ -2291,7 +2302,7 @@ impl PyDiGraph { let new_p_index = new_node_map.get(&edge.source()).unwrap(); let new_c_index = new_node_map.get(&edge.target()).unwrap(); let weight = weight_transform_callable(py, &edge_map_func, edge.weight())?; - self._add_edge(*new_p_index, *new_c_index, weight)?; + self._add_edge(*new_p_index, *new_c_index, weight, false)?; } // Add edges from map for (this_index, (index, weight)) in node_map.iter() { @@ -2300,6 +2311,7 @@ impl PyDiGraph { NodeIndex::new(*this_index), *new_index, weight.clone_ref(py), + false, )?; } let out_dict = PyDict::new(py); @@ -2405,6 +2417,7 @@ impl PyDiGraph { NodeIndex::new(out_map[&edge.source().index()]), NodeIndex::new(out_map[&edge.target().index()]), weight_map_fn(edge.weight(), &edge_weight_map)?, + false, )?; } // Add edges to/from node to nodes in other @@ -2432,7 +2445,7 @@ impl PyDiGraph { }, None => continue, }; - self._add_edge(source, target_out, weight)?; + self._add_edge(source, target_out, weight, false)?; } for (source, target, weight) in out_edges { let old_index = map_fn(source.index(), target.index(), &weight)?; @@ -2448,7 +2461,7 @@ impl PyDiGraph { }, None => continue, }; - self._add_edge(source_out, target, weight)?; + self._add_edge(source_out, target, weight, false)?; } // Remove node self.graph.remove_node(node_index); @@ -2652,8 +2665,21 @@ impl PyDiGraph { /// after the edge is mapped from ``other``. If not specified the weight /// from the edge in ``other`` will be copied by reference and used. /// + /// :param bool cycle_check: To check and raise if the substitution would introduce a cycle. + /// If set to ``True`` or :attr:`.check_cycle` is set to ``True`` when a cycle would be + /// added a :class:`~.DAGWouldCycle` exception will be raised. However, in this case the + /// state of the graph will be partially through the internal steps required for the + /// substitution. If your intent is to detect and use the graph if a + /// cycle were to be detected, you should make a copy of the graph + /// (see :meth:`.copy`) prior to calling this method so you have a + /// copy of the input graph to use. + /// /// :returns: A mapping of node indices in ``other`` to the new node index in this graph /// :rtype: NodeMap + /// + /// :raises DAGWouldCycle: If ``cycle_check`` or the :attr:`.check_cycle` attribute are set to + /// ``True`` and a cycle were to be introduced by the substitution. + #[pyo3(signature=(nodes, other, input_node_map, edge_weight_map=None, cycle_check=false))] pub fn substitute_subgraph( &mut self, py: Python, @@ -2661,6 +2687,7 @@ impl PyDiGraph { other: &PyDiGraph, input_node_map: HashMap, edge_weight_map: Option, + cycle_check: bool, ) -> PyResult { let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); let mut node_map: IndexMap = @@ -2710,11 +2737,12 @@ impl PyDiGraph { for edge in other.graph.edge_references() { let new_source = node_map[edge.source().index()]; let new_target = node_map[edge.target().index()]; - self.graph.add_edge( + self._add_edge( NodeIndex::new(new_source), NodeIndex::new(new_target), weight_map_fn(edge.weight(), &edge_weight_map)?, - ); + cycle_check, + )?; } for edge in io_nodes { let old_source = edge.0; @@ -2747,7 +2775,7 @@ impl PyDiGraph { } else { old_target }; - self.graph.add_edge(new_source, new_target, edge.2); + self._add_edge(new_source, new_target, edge.2, cycle_check)?; } Ok(NodeMap { node_map }) } @@ -2863,7 +2891,7 @@ impl PyDiGraph { Some(callback) => callback.call1(py, (forward_weight,))?, None => forward_weight.clone_ref(py), }; - self._add_edge(*edge_target, *edge_source, weight)?; + self._add_edge(*edge_target, *edge_source, weight, false)?; } } Ok(()) From 3f90b7276d5109af40ea2121fa7fc1ccad7b076a Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 18 Jan 2024 11:39:39 -0500 Subject: [PATCH 5/6] Fix docs typos --- src/digraph.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index 983aa9f44b..03f530edf2 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2626,10 +2626,10 @@ impl PyDiGraph { /// /// :param list nodes: A list of nodes in this graph representing the subgraph /// to be removed. - /// :param PyDiGraph subgraph: The subgraph to replace ``nodes`` with - /// :param dict input_node_map: The mapping of node indices from ```nodes`` to a node + /// :param PyDiGraph other: The subgraph to replace ``nodes`` with + /// :param dict input_node_map: The mapping of node indices from ``nodes`` to a node /// in ``subgraph``. This is used for incoming and outgoing edges into the removed - /// subgraph. This will replace any edges conneted to a node in ``nodes`` with the + /// subgraph. This will replace any edges connected to a node in ``nodes`` with the /// other endpoint outside ``nodes`` where the node in ``nodes`` replaced via this /// mapping. /// :param callable edge_weight_map: An optional callable object that when @@ -2641,7 +2641,7 @@ impl PyDiGraph { /// :param bool cycle_check: To check and raise if the substitution would introduce a cycle. /// If set to ``True`` or :attr:`.check_cycle` is set to ``True`` when a cycle would be /// added a :class:`~.DAGWouldCycle` exception will be raised. However, in this case the - /// state of the graph will be partially through the internal steps required for the + /// state of the graph will be partially modified through the internal steps required for the /// substitution. If your intent is to detect and use the graph if a /// cycle were to be detected, you should make a copy of the graph /// (see :meth:`.copy`) prior to calling this method so you have a @@ -2651,7 +2651,7 @@ impl PyDiGraph { /// :rtype: NodeMap /// /// :raises DAGWouldCycle: If ``cycle_check`` or the :attr:`.check_cycle` attribute are set to - /// ``True`` and a cycle were to be introduced by the substitution. + /// ``True`` and a cycle would be introduced by the substitution. #[pyo3(signature=(nodes, other, input_node_map, edge_weight_map=None, cycle_check=false))] pub fn substitute_subgraph( &mut self, From aeee114e8693655c3b050db486960fd609829e00 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 18 Jan 2024 11:50:29 -0500 Subject: [PATCH 6/6] Separate io_nodes into 2 vecs --- src/digraph.rs | 54 +++++++++---------- src/graph.rs | 2 - .../digraph/test_substitute_subgraph.py | 5 +- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index 03f530edf2..2d392c5c41 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2662,7 +2662,8 @@ impl PyDiGraph { edge_weight_map: Option, cycle_check: bool, ) -> PyResult { - let mut io_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut in_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); + let mut out_nodes: Vec<(NodeIndex, NodeIndex, PyObject)> = Vec::new(); let mut node_map: IndexMap = IndexMap::with_capacity_and_hasher( other.graph.node_count(), @@ -2678,7 +2679,7 @@ impl PyDiGraph { }; for node in nodes { let index = NodeIndex::new(node); - io_nodes.extend( + in_nodes.extend( self.graph .edges_directed(index, petgraph::Direction::Incoming) .filter_map(|edge| { @@ -2689,7 +2690,7 @@ impl PyDiGraph { } }), ); - io_nodes.extend( + out_nodes.extend( self.graph .edges_directed(index, petgraph::Direction::Outgoing) .filter_map(|edge| { @@ -2717,38 +2718,33 @@ impl PyDiGraph { cycle_check, )?; } - for edge in io_nodes { + for edge in out_nodes { let old_source = edge.0; - let new_source = if removed_nodes.contains(&old_source) { - match input_node_map.get(&old_source.index()) { - Some(new_source) => NodeIndex::new(node_map[new_source]), - None => { - let missing_index = old_source.index(); - return Err(PyIndexError::new_err(format!( - "Input/Output node {} not found in io_node_map", - missing_index - ))); - } + let new_source = match input_node_map.get(&old_source.index()) { + Some(new_source) => NodeIndex::new(node_map[new_source]), + None => { + let missing_index = old_source.index(); + return Err(PyIndexError::new_err(format!( + "Input node {} not found in io_node_map", + missing_index + ))); } - } else { - old_source }; + self._add_edge(new_source, edge.1, edge.2, cycle_check)?; + } + for edge in in_nodes { let old_target = edge.1; - let new_target = if removed_nodes.contains(&old_target) { - match input_node_map.get(&old_target.index()) { - Some(new_target) => NodeIndex::new(node_map[new_target]), - None => { - let missing_index = old_target.index(); - return Err(PyIndexError::new_err(format!( - "Input/Output node {} not found in io_node_map", - missing_index - ))); - } + let new_target = match input_node_map.get(&old_target.index()) { + Some(new_target) => NodeIndex::new(node_map[new_target]), + None => { + let missing_index = old_target.index(); + return Err(PyIndexError::new_err(format!( + "Output node {} not found in io_node_map", + missing_index + ))); } - } else { - old_target }; - self._add_edge(new_source, new_target, edge.2, cycle_check)?; + self._add_edge(edge.0, new_target, edge.2, cycle_check)?; } Ok(NodeMap { node_map }) } diff --git a/src/graph.rs b/src/graph.rs index 9c2224ecec..96fb887504 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -35,8 +35,6 @@ use num_traits::Zero; use numpy::Complex64; use numpy::PyReadonlyArray2; -use crate::iterators::NodeMap; - use super::dot_utils::build_dot; use super::iterators::{ EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, NodeMap, WeightedEdgeList, diff --git a/tests/rustworkx_tests/digraph/test_substitute_subgraph.py b/tests/rustworkx_tests/digraph/test_substitute_subgraph.py index a51be37de6..25004dea2b 100644 --- a/tests/rustworkx_tests/digraph/test_substitute_subgraph.py +++ b/tests/rustworkx_tests/digraph/test_substitute_subgraph.py @@ -29,7 +29,7 @@ def test_single_node(self): in_graph.add_node(0) in_graph.add_child(0, 1, "edge") res = self.graph.substitute_subgraph([2], in_graph, {2: 0}) - self.assertEqual([(0, 1), (2, 5), (1, 2), (3, 4), (2, 3)], self.graph.edge_list()) + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (1, 2)], self.graph.edge_list()) self.assertEqual("edge", self.graph.get_edge_data(2, 5)) self.assertEqual(res, {0: 2, 1: 5}) @@ -43,7 +43,8 @@ def test_edge_weight_modifier(self): {2: 0}, edge_weight_map=lambda edge: edge + "-migrated", ) - self.assertEqual([(0, 1), (2, 5), (1, 2), (3, 4), (2, 3)], self.graph.edge_list()) + + self.assertEqual([(0, 1), (2, 5), (2, 3), (3, 4), (1, 2)], self.graph.edge_list()) self.assertEqual("edge-migrated", self.graph.get_edge_data(2, 5)) self.assertEqual(res, {0: 2, 1: 5})