From 80c7d19b798d9ad1e1c355fb9a4906f94ead372b Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 18 Jun 2024 12:25:10 +0100 Subject: [PATCH] use Counter multiset for edge_set --- pyzx/graph/diff.py | 14 ++++++-------- pyzx/graph/multigraph.py | 5 ++++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pyzx/graph/diff.py b/pyzx/graph/diff.py index b838e6fd..5f22057a 100644 --- a/pyzx/graph/diff.py +++ b/pyzx/graph/diff.py @@ -51,17 +51,15 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None: new_verts = g2.vertex_set() self.removed_verts = list(old_verts - new_verts) self.new_verts = list(new_verts - old_verts) + old_edges = g1.edge_set() + new_edges = g2.edge_set() self.new_edges = [] self.removed_edges = [] - g1_edges = list(g1.edges()) - new_edges = [i for i in g2.edges() if not i in g1_edges or g1_edges.remove(i)] - for e in (new_edges): + for e in (new_edges - old_edges): self.new_edges.append((g2.edge_st(e), g2.edge_type(e))) - g2_edges = list(g2.edges()) - removed_edges = [i for i in g1.edges() if not i in g2_edges or g2_edges.remove(i)] - for e in (removed_edges): + for e in (old_edges - new_edges): s,t = g1.edge_st(e) if s in self.removed_verts or t in self.removed_verts: continue self.removed_edges.append(e) @@ -86,8 +84,8 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None: pos2 = g2.qubit(v), g2.row(v) self.changed_pos[v] = pos2 - for e in g2.edges(): - if e in g1.edges(): + for e in new_edges: + if e in old_edges: if g1.edge_type(e) != g2.edge_type(e): self.changed_edge_types[e] = g2.edge_type(e) else: diff --git a/pyzx/graph/multigraph.py b/pyzx/graph/multigraph.py index 5d0f3a14..7a2541ce 100644 --- a/pyzx/graph/multigraph.py +++ b/pyzx/graph/multigraph.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import Counter from fractions import Fraction from typing import Tuple, Dict, Set, Any @@ -304,8 +305,10 @@ def edges(self, s=None, t=None): def edge(self, s, t): return (s,t) if s < t else (t,s) + def edge_set(self): - return set(self.edges()) + return Counter(self.edges()) + def edge_st(self, edge): return (edge[0], edge[1])