Skip to content

Commit

Permalink
use Counter multiset for edge_set
Browse files Browse the repository at this point in the history
  • Loading branch information
RazinShaikh committed Jun 18, 2024
1 parent 8ba2230 commit 80c7d19
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
14 changes: 6 additions & 8 deletions pyzx/graph/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,15 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:
new_verts = g2.vertex_set()
self.removed_verts = list(old_verts - new_verts)
self.new_verts = list(new_verts - old_verts)
old_edges = g1.edge_set()
new_edges = g2.edge_set()
self.new_edges = []
self.removed_edges = []

g1_edges = list(g1.edges())
new_edges = [i for i in g2.edges() if not i in g1_edges or g1_edges.remove(i)]
for e in (new_edges):
for e in (new_edges - old_edges):
self.new_edges.append((g2.edge_st(e), g2.edge_type(e)))

g2_edges = list(g2.edges())
removed_edges = [i for i in g1.edges() if not i in g2_edges or g2_edges.remove(i)]
for e in (removed_edges):
for e in (old_edges - new_edges):
s,t = g1.edge_st(e)
if s in self.removed_verts or t in self.removed_verts: continue
self.removed_edges.append(e)
Expand All @@ -86,8 +84,8 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:
pos2 = g2.qubit(v), g2.row(v)
self.changed_pos[v] = pos2

for e in g2.edges():
if e in g1.edges():
for e in new_edges:
if e in old_edges:
if g1.edge_type(e) != g2.edge_type(e):
self.changed_edge_types[e] = g2.edge_type(e)
else:
Expand Down
5 changes: 4 additions & 1 deletion pyzx/graph/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Counter
from fractions import Fraction
from typing import Tuple, Dict, Set, Any

Expand Down Expand Up @@ -304,8 +305,10 @@ def edges(self, s=None, t=None):

def edge(self, s, t):
return (s,t) if s < t else (t,s)

def edge_set(self):
return set(self.edges())
return Counter(self.edges())

def edge_st(self, edge):
return (edge[0], edge[1])

Expand Down

0 comments on commit 80c7d19

Please sign in to comment.