From e4f1eab711e0abb37858d85fd2f32f9b6831eaac Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Tue, 18 Jun 2024 17:18:26 +0200 Subject: [PATCH] Made `is_graph_like` not check for connectivity to boundaries by default, as that is a more sensible option. See #238 --- pyzx/altextract.py | 6 +++--- pyzx/circuit/gates.py | 4 ++-- pyzx/editor.py | 2 +- pyzx/extract.py | 2 +- pyzx/generate.py | 10 ++++----- pyzx/graph/base.py | 20 +++++++++--------- pyzx/graph/diff.py | 6 +++--- pyzx/graph/graph_s.py | 4 ++-- pyzx/graph/jsonparser.py | 8 +++---- pyzx/graph/multigraph.py | 6 +++--- pyzx/hrules.py | 4 ++-- pyzx/io.py | 6 +++--- pyzx/mbqc.py | 4 ++-- pyzx/rules.py | 10 ++++----- pyzx/simplify.py | 32 +++++++++++++++------------- pyzx/tensor.py | 2 +- pyzx/tikz.py | 8 +++---- pyzx/utils.py | 45 ++++++++++++++++++++-------------------- 18 files changed, 91 insertions(+), 88 deletions(-) diff --git a/pyzx/altextract.py b/pyzx/altextract.py index 4f88e5ab..606a0a7d 100644 --- a/pyzx/altextract.py +++ b/pyzx/altextract.py @@ -127,7 +127,7 @@ def alt_extract_circuit( e = g.edge(v,b) if g.edge_type(e) == 2: # Hadamard edge c.add_gate("HAD",q) - g.set_edge_type(e,1) + g.set_edge_type(e, EdgeType.SIMPLE) if phases[v]: c.add_gate("ZPhase", q, phases[v]) g.set_phase(v,0) @@ -174,11 +174,11 @@ def alt_extract_circuit( b = [w for w in d if w in inputs][0] q = qs[b] r = rs[b] - w = g.add_vertex(1,q,r+1) + w = g.add_vertex(VertexType.Z, q, r+1) e = g.edge(v,b) et = g.edge_type(e) g.remove_edge(e) - g.add_edge((v,w),2) + g.add_edge((v, w), EdgeType.HADAMARD) g.add_edge((w,b),toggle_edge(et)) d.remove(b) d.append(w) diff --git a/pyzx/circuit/gates.py b/pyzx/circuit/gates.py index 11a37a4c..d4972e64 100644 --- a/pyzx/circuit/gates.py +++ b/pyzx/circuit/gates.py @@ -293,10 +293,10 @@ def to_graph(self, g: BaseGraph[VT,ET], q_mapper: TargetMapper[VT], c_mapper: Ta def graph_add_node(self, g: BaseGraph[VT,ET], mapper: TargetMapper[VT], - t: VertexType.Type, + t: VertexType, l: int, r: int, phase: FractionLike=0, - etype: EdgeType.Type=EdgeType.SIMPLE, + etype: EdgeType=EdgeType.SIMPLE, ground: bool = False) -> VT: v = g.add_vertex(t, mapper.to_qubit(l), r, phase, ground) g.add_edge((mapper.prev_vertex(l), v), etype) diff --git a/pyzx/editor.py b/pyzx/editor.py index 7d36b63f..2a7e21c5 100644 --- a/pyzx/editor.py +++ b/pyzx/editor.py @@ -117,7 +117,7 @@ def load_js() -> None: """.format(settings.d3_load_string,data1,data2) display(HTML(text)) -def s_to_phase(s: str, t:VertexType.Type=VertexType.Z) -> Fraction: +def s_to_phase(s: str, t:VertexType=VertexType.Z) -> Fraction: if not s: if t!= VertexType.H_BOX: return Fraction(0) else: return Fraction(1) diff --git a/pyzx/extract.py b/pyzx/extract.py index e1cfb0ed..c1de2ec4 100644 --- a/pyzx/extract.py +++ b/pyzx/extract.py @@ -42,7 +42,7 @@ def connectivity_from_biadj( m: Mat2, left:List[VT], right: List[VT], - edgetype:EdgeType.Type=EdgeType.HADAMARD): + edgetype:EdgeType=EdgeType.HADAMARD): """Replace the connectivity in ``g`` between the vertices in ``left`` and ``right`` by the biadjacency matrix ``m``. The edges will be of type ``edgetype``.""" for i in range(len(right)): diff --git a/pyzx/generate.py b/pyzx/generate.py index 349d4a4a..a87dcce8 100644 --- a/pyzx/generate.py +++ b/pyzx/generate.py @@ -70,7 +70,7 @@ def identity(qubits: int, depth: FloatInt=1,backend:Optional[str]=None) -> BaseG return g def spider( - typ:Union[Literal["Z"], Literal["X"], Literal["H"], Literal["W"], Literal["ZBox"], VertexType.Type], + typ:Union[Literal["Z"], Literal["X"], Literal["H"], Literal["W"], Literal["ZBox"], VertexType], inputs: int, outputs: int, phase:Optional[Union[FractionLike, complex]]=None, @@ -176,7 +176,7 @@ def cnots(qubits: int, depth: int, backend:Optional[str]=None) -> BaseGraph: q: List[int] = list(range(qubits)) # qubit index, initialised with input r: int = 1 # current rank - ty: List[VertexType.Type] = [VertexType.BOUNDARY] * qubits # types of vertices + ty: List[VertexType] = [VertexType.BOUNDARY] * qubits # types of vertices qs: List[int] = list(range(qubits)) # tracks qubit indices of vertices rs: List[int] = [0] * qubits # tracks rank of vertices v = qubits # next vertex to add @@ -425,7 +425,7 @@ def cliffords( q = list(range(qubits)) # qubit index, initialised with input r = 1 # current rank - ty: List[VertexType.Type] = [VertexType.BOUNDARY] * qubits # types of vertices + ty: List[VertexType] = [VertexType.BOUNDARY] * qubits # types of vertices qs = list(range(qubits)) # tracks qubit indices of vertices rs = [0] * qubits # tracks rank of vertices v = qubits # next vertex to add @@ -439,7 +439,7 @@ def cliffords( q[i] = v rs.append(r) qs.append(i) - ty.append(1) + ty.append(VertexType.Z) v += 1 r += 1 @@ -454,7 +454,7 @@ def cliffords( ty += [VertexType.Z, VertexType.X] else: es2.append((v,v+1)) - typ: VertexType.Type = random.choice([VertexType.Z, VertexType.X]) + typ: VertexType = random.choice([VertexType.Z, VertexType.X]) ty += [typ, typ] if accept(p_phase): phases[v] = random_phase(t_gates) if accept(p_phase): phases[v+1] = random_phase(t_gates) diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 8f6cfec4..aec61b40 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -267,7 +267,7 @@ def compose(self, other: 'BaseGraph') -> None: if len(outputs) != len(inputs): raise TypeError("Outputs of first graph must match inputs of second.") - plugs: List[Tuple[VT,VT,EdgeType.Type]] = [] + plugs: List[Tuple[VT,VT,EdgeType]] = [] for k in range(len(outputs)): o = outputs[k] i = inputs[k] @@ -664,7 +664,7 @@ def add_vertices(self, amount: int) -> List[VT]: raise NotImplementedError("Not implemented on backend " + type(self).backend) def add_vertex(self, - ty:VertexType.Type=VertexType.BOUNDARY, + ty:VertexType=VertexType.BOUNDARY, qubit:FloatInt=-1, row:FloatInt=-1, phase:Optional[FractionLike]=None, @@ -697,11 +697,11 @@ def add_vertex_indexed(self,v:VT) -> None: which requires vertices to preserve their index.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def add_edges(self, edge_pairs: Iterable[Tuple[VT,VT]], edgetype:EdgeType.Type=EdgeType.SIMPLE) -> None: + def add_edges(self, edge_pairs: Iterable[Tuple[VT,VT]], edgetype:EdgeType=EdgeType.SIMPLE) -> None: """Adds a list of edges to the graph.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def add_edge(self, edge_pair: Tuple[VT,VT], edgetype:EdgeType.Type=EdgeType.SIMPLE) -> ET: + def add_edge(self, edge_pair: Tuple[VT,VT], edgetype:EdgeType=EdgeType.SIMPLE) -> ET: """Adds a single edge of the given type and return its id""" raise NotImplementedError("Not implemented on backend " + type(self).backend) @@ -831,7 +831,7 @@ def edge_set(self) -> Set[ET]: Should be overloaded if the backend supplies a cheaper version than this. Note this ignores parallel edges.""" return set(self.edges()) - def edge(self, s:VT, t:VT, et: EdgeType.Type=EdgeType.SIMPLE) -> ET: + def edge(self, s:VT, t:VT, et: EdgeType=EdgeType.SIMPLE) -> ET: """Returns the name of the first edge with the given source/target and type. Behaviour is undefined if the vertices are not connected.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) @@ -863,26 +863,26 @@ def connected(self,v1: VT,v2: VT) -> bool: """Returns whether vertices v1 and v2 share an edge.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def edge_type(self, e: ET) -> EdgeType.Type: + def edge_type(self, e: ET) -> EdgeType: """Returns the type of the given edge: ``EdgeType.SIMPLE`` if it is regular, ``EdgeType.HADAMARD`` if it is a Hadamard edge, 0 if the edge is not in the graph.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def set_edge_type(self, e: ET, t: EdgeType.Type) -> None: + def set_edge_type(self, e: ET, t: EdgeType) -> None: """Sets the type of the given edge.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def type(self, vertex: VT) -> VertexType.Type: + def type(self, vertex: VT) -> VertexType: """Returns the type of the given vertex: VertexType.BOUNDARY if it is a boundary, VertexType.Z if it is a Z node, VertexType.X if it is a X node, VertexType.H_BOX if it is an H-box.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def types(self) -> Mapping[VT, VertexType.Type]: + def types(self) -> Mapping[VT, VertexType]: """Returns a mapping of vertices to their types.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def set_type(self, vertex: VT, t: VertexType.Type) -> None: + def set_type(self, vertex: VT, t: VertexType) -> None: """Sets the type of the given vertex to t.""" raise NotImplementedError("Not implemented on backend" + type(self).backend) diff --git a/pyzx/graph/diff.py b/pyzx/graph/diff.py index 5f22057a..b337bf1d 100644 --- a/pyzx/graph/diff.py +++ b/pyzx/graph/diff.py @@ -27,9 +27,9 @@ class GraphDiff(Generic[VT, ET]): removed_verts: List[VT] new_verts: List[VT] removed_edges: List[ET] - new_edges: List[Tuple[Tuple[VT,VT],EdgeType.Type]] - changed_vertex_types: Dict[VT,VertexType.Type] - changed_edge_types: Dict[ET, EdgeType.Type] + new_edges: List[Tuple[Tuple[VT,VT],EdgeType]] + changed_vertex_types: Dict[VT,VertexType] + changed_edge_types: Dict[ET, EdgeType] changed_phases: Dict[VT, FractionLike] changed_pos: Dict[VT, Tuple[FloatInt,FloatInt]] changed_vdata: Dict[VT, Any] diff --git a/pyzx/graph/graph_s.py b/pyzx/graph/graph_s.py index a4e47214..759c8105 100644 --- a/pyzx/graph/graph_s.py +++ b/pyzx/graph/graph_s.py @@ -29,10 +29,10 @@ class GraphS(BaseGraph[int,Tuple[int,int]]): #can be found in base.BaseGraph def __init__(self) -> None: BaseGraph.__init__(self) - self.graph: Dict[int,Dict[int,EdgeType.Type]] = dict() + self.graph: Dict[int,Dict[int,EdgeType]] = dict() self._vindex: int = 0 self.nedges: int = 0 - self.ty: Dict[int,VertexType.Type] = dict() + self.ty: Dict[int,VertexType] = dict() self._phase: Dict[int, FractionLike] = dict() self._qindex: Dict[int, FloatInt] = dict() self._maxq: FloatInt = -1 diff --git a/pyzx/graph/jsonparser.py b/pyzx/graph/jsonparser.py index 45d282f2..e2bb6386 100644 --- a/pyzx/graph/jsonparser.py +++ b/pyzx/graph/jsonparser.py @@ -247,11 +247,11 @@ def graph_to_json(g: BaseGraph[VT,ET], include_scalar: bool=True) -> str: i = 0 for e in g.edges(): src,tgt = g.edge_st(e) - t = g.edge_type(e) - if t == EdgeType.SIMPLE: + et = g.edge_type(e) + if et == EdgeType.SIMPLE: edges["e"+ str(i)] = {"src": names[src],"tgt": names[tgt]} i += 1 - elif t == EdgeType.HADAMARD: + elif et == EdgeType.HADAMARD: x1,y1 = g.row(src), -g.qubit(src) x2,y2 = g.row(tgt), -g.qubit(tgt) hadname = freenamesv.pop(0) @@ -261,7 +261,7 @@ def graph_to_json(g: BaseGraph[VT,ET], include_scalar: bool=True) -> str: i += 1 edges["e"+str(i)] = {"src": names[tgt],"tgt": hadname} i += 1 - elif t == EdgeType.W_IO: + elif et == EdgeType.W_IO: edges["e"+str(i)] = {"src": names[src],"tgt": names[tgt], "type": "w_io"} i += 1 else: diff --git a/pyzx/graph/multigraph.py b/pyzx/graph/multigraph.py index 7a2541ce..f3f4102c 100644 --- a/pyzx/graph/multigraph.py +++ b/pyzx/graph/multigraph.py @@ -51,12 +51,12 @@ def remove(self, s: int=0, h: int=0, w_io: int=0): def is_empty(self) -> bool: return self.s == 0 and self.h == 0 and self.w_io == 0 - def get_edge_count(self, ty: EdgeType.Type) -> int: + def get_edge_count(self, ty: EdgeType) -> int: if ty == EdgeType.SIMPLE: return self.s elif ty == EdgeType.HADAMARD: return self.h else: return self.w_io -class Multigraph(BaseGraph[int,Tuple[int,int,EdgeType.Type]]): +class Multigraph(BaseGraph[int,Tuple[int,int,EdgeType]]): """Purely Pythonic multigraph implementation of :class:`~graph.base.BaseGraph`.""" backend = 'multigraph' @@ -68,7 +68,7 @@ def __init__(self) -> None: self._auto_simplify: bool = True self._vindex: int = 0 self.nedges: int = 0 - self.ty: Dict[int,VertexType.Type] = dict() + self.ty: Dict[int,VertexType] = dict() self._phase: Dict[int, FractionLike] = dict() self._qindex: Dict[int, FloatInt] = dict() self._maxq: FloatInt = -1 diff --git a/pyzx/hrules.py b/pyzx/hrules.py index d51c0281..13d36f44 100644 --- a/pyzx/hrules.py +++ b/pyzx/hrules.py @@ -102,7 +102,7 @@ def fuse_hboxes(g: BaseGraph[VT,ET], matches: List[ET]) -> rules.RewriteOutputTy -MatchCopyType = Tuple[VT,VT,VertexType.Type,FractionLike,FractionLike,List[VT]] +MatchCopyType = Tuple[VT,VT,VertexType,FractionLike,FractionLike,List[VT]] def match_copy( g: BaseGraph[VT,ET], @@ -128,7 +128,7 @@ def match_copy( if tw == VertexType.BOUNDARY: continue e = g.edge(v,w) et = g.edge_type(e) - copy_type: VertexType.Type = VertexType.Z + copy_type: VertexType = VertexType.Z if vertex_is_zx(tv): if vertex_is_zx(tw): if et == EdgeType.HADAMARD: diff --git a/pyzx/io.py b/pyzx/io.py index 85543b64..ffb4b344 100644 --- a/pyzx/io.py +++ b/pyzx/io.py @@ -190,11 +190,11 @@ def graph_to_json(g: BaseGraph[VT,ET], force_deprecated_behavior=False) -> str: i = 0 for e in g.edges(): src,tgt = g.edge_st(e) - t = g.edge_type(e) - if t == EdgeType.SIMPLE: + et = g.edge_type(e) + if et == EdgeType.SIMPLE: edges["e"+ str(i)] = {"src": names[src],"tgt": names[tgt]} i += 1 - elif t==EdgeType.HADAMARD: + elif et==EdgeType.HADAMARD: x1,y1 = g.row(src), -g.qubit(src) x2,y2 = g.row(tgt), -g.qubit(tgt) hadname = freenamesv.pop(0) diff --git a/pyzx/mbqc.py b/pyzx/mbqc.py index c49a04ab..f35bbeb8 100644 --- a/pyzx/mbqc.py +++ b/pyzx/mbqc.py @@ -31,7 +31,7 @@ def cluster_state(m: int, n: int, inputs: List[Tuple[int,int]]=[]) -> BaseGraph: g.set_outputs(tuple(outp)) return g -def measure(g:BaseGraph, pos:Tuple[int,int], t:VertexType.Type=VertexType.Z, phase:FractionLike=0): +def measure(g:BaseGraph, pos:Tuple[int,int], t:VertexType=VertexType.Z, phase:FractionLike=0): """Measure the qubit at the given grid position, basis, and phase.""" q = 2*pos[0]-0.8 r = 2*pos[1]+0.8 @@ -53,7 +53,7 @@ def measure(g:BaseGraph, pos:Tuple[int,int], t:VertexType.Type=VertexType.Z, pha if not found: raise ValueError("Couldn't find a qubit at that position") -def apply_pauli(g:BaseGraph, pos:Tuple[int,int], t:VertexType.Type=VertexType.Z, phase:FractionLike=1): +def apply_pauli(g:BaseGraph, pos:Tuple[int,int], t:VertexType=VertexType.Z, phase:FractionLike=1): """Measure the qubit at the given grid position, basis, and phase.""" if phase == 0: diff --git a/pyzx/rules.py b/pyzx/rules.py index 9767e0b3..1b396a5e 100644 --- a/pyzx/rules.py +++ b/pyzx/rules.py @@ -227,7 +227,7 @@ def spider(g: BaseGraph[VT,ET], matches: List[MatchSpiderType[VT]]) -> RewriteOu if v0 == w: continue e = (v0,w) if e not in etab: etab[e] = [0,0] - etab[e][g.edge_type(g.edge(v1,w))-1] += 1 + etab[e][g.edge_type(g.edge(v1,w)).value - 1] += 1 return (etab, rem_verts, [], True) def unspider(g: BaseGraph[VT,ET], m: List[Any], qubit:FloatInt=-1, row:FloatInt=-1) -> VT: @@ -363,7 +363,7 @@ def w_fusion(g: BaseGraph[VT,ET], matches: List[MatchSpiderType[VT]]) -> Rewrite w = v0_out e = (v0_out, w) if e not in etab: etab[e] = [0,0] - etab[e][g.edge_type(g.edge(v1_out, w)) - 1] += 1 + etab[e][g.edge_type(g.edge(v1_out, w)).value - 1] += 1 return (etab, rem_verts, [], True) @@ -749,7 +749,7 @@ def lcomp(g: BaseGraph[VT,ET], matches: List[MatchLcompType[VT]]) -> RewriteOutp return (etab, rem, [], True) -MatchIdType = Tuple[VT,VT,VT,EdgeType.Type] +MatchIdType = Tuple[VT,VT,VT,EdgeType] def match_ids(g: BaseGraph[VT,ET]) -> List[MatchIdType[VT]]: """Finds a single identity node. See :func:`match_ids_parallel`.""" @@ -1100,8 +1100,8 @@ def apply_gadget_phasepoly(g: BaseGraph[VT,ET], matches: List[MatchPhasePolyType phase = -phase g.set_phase(n,0) else: - n = g.add_vertex(1,-1, rs[group[0]]+0.5) - v = g.add_vertex(1,-2, rs[group[0]]+0.5) + n = g.add_vertex(VertexType.Z, -1, rs[group[0]]+0.5) + v = g.add_vertex(VertexType.Z, -2, rs[group[0]]+0.5) phase = 0 g.add_edges([(n,v)]+[(n,w) for w in group],EdgeType.HADAMARD) g.set_phase(v, phase + Fraction(7,4)) diff --git a/pyzx/simplify.py b/pyzx/simplify.py index 52b75aed..a0cda7f8 100644 --- a/pyzx/simplify.py +++ b/pyzx/simplify.py @@ -407,8 +407,11 @@ def full_reduce_iter(g: BaseGraph[VT,ET]) -> Iterator[Tuple[BaseGraph[VT,ET],str ok = True yield g, f"pivot_gadget -> {step}" -def is_graph_like(g: BaseGraph[VT,ET]) -> bool: - """Checks if a ZX-diagram is graph-like.""" +def is_graph_like(g: BaseGraph[VT,ET], strict:bool=False) -> bool: + """Checks if a ZX-diagram is graph-like: + only contains Z-spiders which are connected by Hadamard edges. + If `strict` is True, then also checks that each boundary vertex is connected to a Z-spider, + and that each Z-spider is connected to at most one boundary.""" # checks that all spiders are Z-spiders for v in g.vertices(): @@ -431,18 +434,19 @@ def is_graph_like(g: BaseGraph[VT,ET]) -> bool: if g.connected(v, v): return False - # every I/O is connected to a Z-spider - bs = [v for v in g.vertices() if g.type(v) == VertexType.BOUNDARY] - for b in bs: - if g.vertex_degree(b) != 1 or g.type(list(g.neighbors(b))[0]) != VertexType.Z: - return False + if strict: + # every I/O is connected to a Z-spider + bs = [v for v in g.vertices() if g.type(v) == VertexType.BOUNDARY] + for b in bs: + if g.vertex_degree(b) != 1 or g.type(list(g.neighbors(b))[0]) != VertexType.Z: + return False - # every Z-spider is connected to at most one I/O - zs = [v for v in g.vertices() if g.type(v) == VertexType.Z] - for z in zs: - b_neighbors = [n for n in g.neighbors(z) if g.type(n) == VertexType.BOUNDARY] - if len(b_neighbors) > 1: - return False + # every Z-spider is connected to at most one I/O + zs = [v for v in g.vertices() if g.type(v) == VertexType.Z] + for z in zs: + b_neighbors = [n for n in g.neighbors(z) if g.type(n) == VertexType.BOUNDARY] + if len(b_neighbors) > 1: + return False return True @@ -510,7 +514,7 @@ def to_graph_like(g: BaseGraph[VT,ET]) -> None: g.add_edge((b,z),EdgeType.SIMPLE) g.add_edge((z,v),EdgeType.HADAMARD) - assert(is_graph_like(g)) + assert(is_graph_like(g,strict=True)) def to_clifford_normal_form_graph(g: BaseGraph[VT,ET]) -> None: """Converts a graph that is Clifford into the form described by the right-hand side of eq. (11) of diff --git a/pyzx/tensor.py b/pyzx/tensor.py index 1e931fc9..ef1276c0 100644 --- a/pyzx/tensor.py +++ b/pyzx/tensor.py @@ -163,7 +163,7 @@ def tensorfy(g: 'BaseGraph[VT,ET]', preserve_scalar:bool=True) -> np.ndarray: raise ValueError("Vertex %s has non-ZXH type but is not an input or output" % str(v)) nn = list(filter(lambda n: rows[n] bool: + BOUNDARY = 0 + Z = 1 + X = 2 + H_BOX = 3 + W_INPUT = 4 + W_OUTPUT = 5 + Z_BOX = 6 + +def vertex_is_zx(ty: VertexType) -> bool: """Check if a vertex type corresponds to a green or red spider.""" return ty in (VertexType.Z, VertexType.X) -def toggle_vertex(ty: VertexType.Type) -> VertexType.Type: +def toggle_vertex(ty: VertexType) -> VertexType: """Swap the X and Z vertex types.""" if not vertex_is_zx(ty): return ty return VertexType.Z if ty == VertexType.X else VertexType.X -def vertex_is_z_like(ty: VertexType.Type) -> bool: +def vertex_is_z_like(ty: VertexType) -> bool: """Check if a vertex type corresponds to a Z spider or Z box.""" return ty == VertexType.Z or ty == VertexType.Z_BOX -def vertex_is_zx_like(ty: VertexType.Type) -> bool: +def vertex_is_zx_like(ty: VertexType) -> bool: """Check if a vertex type corresponds to a Z or X spider or Z box.""" return vertex_is_z_like(ty) or ty == VertexType.X -def vertex_is_w(ty: VertexType.Type) -> bool: +def vertex_is_w(ty: VertexType) -> bool: return ty == VertexType.W_INPUT or ty == VertexType.W_OUTPUT def get_w_partner(g, v): @@ -73,24 +73,23 @@ def get_w_io(g, v): return v2, v -class EdgeType: +class EdgeType(IntEnum): """Type of an edge in the graph.""" - Type = Literal[1, 2, 3] - SIMPLE: Final = 1 - HADAMARD: Final = 2 - W_IO: Final = 3 + SIMPLE = 1 + HADAMARD = 2 + W_IO = 3 -def toggle_edge(ty: EdgeType.Type) -> EdgeType.Type: +def toggle_edge(ty: EdgeType) -> EdgeType: """Swap the regular and Hadamard edge types.""" return EdgeType.HADAMARD if ty == EdgeType.SIMPLE else EdgeType.SIMPLE -def phase_to_s(a: FractionLike, t:VertexType.Type=VertexType.Z) -> str: +def phase_to_s(a: FractionLike, t:VertexType=VertexType.Z) -> str: if isinstance(a, Fraction) or isinstance(a, int): return phase_fraction_to_s(a, t) else: # a is a Poly return str(a) -def phase_fraction_to_s(a: FractionLike, t:VertexType.Type=VertexType.Z) -> str: +def phase_fraction_to_s(a: FractionLike, t:VertexType=VertexType.Z) -> str: if (a == 0 and t != VertexType.H_BOX): return '' if (a == 1 and t == VertexType.H_BOX): return '' if isinstance(a, Poly):