diff --git a/quick/circuit/circuit.py b/quick/circuit/circuit.py index 085518c..4fdd5ee 100644 --- a/quick/circuit/circuit.py +++ b/quick/circuit/circuit.py @@ -252,7 +252,7 @@ def _validate_qubit_index( """ if name in ALL_QUBIT_KEYS: if isinstance(value, list): - # For simplicity, we consider the [i] index to be just 1 (int instead of list) + # For simplicity, we consider the [i] index to be just i (int instead of list) if len(value) == 1: value = self._process_single_qubit_index(value[0]) else: @@ -285,7 +285,10 @@ def _validate_single_angle( - Angle must be a number. """ if not isinstance(angle, (int, float)): - raise TypeError(f"Angle must be a number. Unexpected type {type(angle)} received.") + raise TypeError( + "Angle must be a number. " + f"Received {type(angle)} instead." + ) if abs(angle) <= EPSILON or abs(angle % PI_DOUBLE) <= EPSILON: angle = 0.0 @@ -395,7 +398,7 @@ def process_gate_params( params[name] = value - if sorted(list(set(qubit_indices))) != sorted(qubit_indices): + if len(set(qubit_indices)) != len(qubit_indices): raise ValueError( "Qubit indices must be unique. " f"Received {qubit_indices} instead." @@ -5729,23 +5732,20 @@ def compress( # Define angle closeness threshold threshold = PI * compression_percentage - # Initialize a list for the indices that will be removed - indices_to_remove = [] + new_circuit_log = [] # Iterate over all angles, and set the angles within the # compression percentage to 0 (this means the gate does nothing, and can be removed) - for index, operation in enumerate(self.circuit_log): + for operation in self.circuit_log: if "angle" in operation: if abs(operation["angle"]) < threshold: - indices_to_remove.append(index) + continue elif "angles" in operation: if all([abs(angle) < threshold for angle in operation["angles"]]): - indices_to_remove.append(index) - - # Remove the operations with angles within the compression percentage - for index in sorted(indices_to_remove, reverse=True): - del self.circuit_log[index] + continue + new_circuit_log.append(operation) + self.circuit_log = new_circuit_log self.update() def change_mapping( @@ -5771,24 +5771,18 @@ def change_mapping( ----- >>> circuit.change_mapping(qubit_indices=[1, 0]) """ - if not all(isinstance(index, int) for index in qubit_indices): - raise TypeError("Qubit indices must be a collection of integers.") - - if sorted(list(set(qubit_indices))) != list(range(self.num_qubits)): + if len(set(qubit_indices)) != self.num_qubits: raise ValueError("Qubit indices must be unique.") - if isinstance(qubit_indices, Sequence): - qubit_indices = list(qubit_indices) - elif isinstance(qubit_indices, np.ndarray): - qubit_indices = qubit_indices.tolist() + if any(not isinstance(qubit_index, int) for qubit_index in qubit_indices): + raise TypeError("Qubit indices must all be integers.") if self.num_qubits != len(qubit_indices): raise ValueError("The number of qubits must match the number of qubits in the circuit.") - # Update the qubit indices for operation in self.circuit_log: for key in set(operation.keys()).intersection(ALL_QUBIT_KEYS): - if isinstance(operation[key], list): + if isinstance(operation[key], Sequence): operation[key] = [qubit_indices[index] for index in operation[key]] else: operation[key] = qubit_indices[operation[key]] @@ -5829,7 +5823,7 @@ def convert( # Iterate over the gate log and apply corresponding gates in the new framework for gate_info in self.circuit_log: # Extract gate name and remove it from gate_info for kwargs - gate_name = gate_info.pop("gate", None) + gate_name = gate_info.pop("gate") # Extract gate definition and remove it from gate_info for kwargs gate_definition = gate_info.pop("definition", None) @@ -6328,8 +6322,58 @@ def __eq__( >>> circuit1 == circuit2 """ if not isinstance(other_circuit, Circuit): - raise TypeError("Circuits must be compared with other circuits.") - return self.circuit_log == other_circuit.circuit_log + raise TypeError( + "Circuits can only be compared with other Circuits. " + f"Received {type(other_circuit)} instead." + ) + return self.get_dag() == other_circuit.get_dag() + + def is_equivalent( + self, + other_circuit: Circuit, + check_unitary: bool=True, + check_dag: bool=False + ) -> bool: + """ Check if the circuit is equivalent to another circuit. + + Parameters + ---------- + `other_circuit` : quick.circuit.Circuit + The other circuit to compare to. + `check_unitary` : bool, optional, default=True + Whether or not to check the unitary of the circuit. + `check_dag` : bool, optional, default=False + Whether or not to check the DAG of the circuit. + + Returns + ------- + bool + Whether the two circuits are equivalent. + + Raises + ------ + TypeError + - Circuits must be compared with other circuits. + + Usage + ----- + >>> circuit1.is_equivalent(circuit2) + """ + if not isinstance(other_circuit, Circuit): + raise TypeError( + "Circuits can only be compared with other Circuits. " + f"Received {type(other_circuit)} instead." + ) + + if check_unitary: + if not np.allclose(self.get_unitary(), other_circuit.get_unitary()): + return False + + if check_dag: + if self.get_dag() != other_circuit.get_dag(): + return False + + return True def __len__(self) -> int: """ Get the number of the circuit operations. diff --git a/quick/circuit/dag/dagcircuit.py b/quick/circuit/dag/dagcircuit.py index b4fff2e..7b35520 100644 --- a/quick/circuit/dag/dagcircuit.py +++ b/quick/circuit/dag/dagcircuit.py @@ -78,7 +78,33 @@ def add_operation( """ from quick.circuit.circuit import ALL_QUBIT_KEYS - gate_node = DAGNode(operation["gate"]) + # Extract the gate parameters + params = dict() + meta_params = dict() + + # Log all parameters except 'definition' and 'gate' + # in the meta params to ensure uniqueness of the node + # when the gate is repeated on multiple qubits in + # parallel, i.e., a CX followed by two H on control + # and target + # This avoids `__eq__` issues with DAGNode + for key in operation: + if key not in ALL_QUBIT_KEYS.union(['definition', 'gate']): + params[key] = operation[key] + if key not in ['definition', 'gate']: + meta_params[key] = operation[key] + + # Define the name of the node + # For simplicity, we omit the empty params for gates + # that only have qubit indices as parameter + if params == {}: + node_name = f"{operation['gate']}" + else: + node_name = f"{operation['gate']}({str(params).strip('{}')})" + + meta_name = f"{operation['gate']}({str(meta_params).strip('{}')})" + + gate_node = DAGNode(node_name, meta_name=meta_name) qubit_indices = [] # Add qubits from any valid qubit key to the @@ -108,6 +134,48 @@ def get_depth(self) -> int: """ return max(qubit.depth for qubit in self.qubits.values()) + def __eq__( + self, + other_circuit: object + ) -> bool: + """ Check if two circuits are equal. + + Parameters + ---------- + `other_circuit` : object + The other circuit to compare with. + + Returns + ------- + bool + True if the two circuits are equal, False otherwise. + + Raises + ------ + TypeError + If the other circuit is not an instance of `quick.circuit.dag.DAGCircuit`. + + Usage + ----- + >>> dag1 = DAGCircuit(2) + >>> dag2 = DAGCircuit(2) + >>> dag1 == dag2 + """ + if not isinstance(other_circuit, DAGCircuit): + raise TypeError( + "DAGCircuits can only be compared with other DAGCircuits. " + f"Received {type(other_circuit)} instead." + ) + + if self.num_qubits != other_circuit.num_qubits: + return False + + for qubit in self.qubits: + if self.qubits[qubit] != other_circuit.qubits[qubit]: + return False + + return True + def __repr__(self) -> str: """ Get the string representation of the circuit. diff --git a/quick/circuit/dag/dagnode.py b/quick/circuit/dag/dagnode.py index 3cec2ed..fda1fd4 100644 --- a/quick/circuit/dag/dagnode.py +++ b/quick/circuit/dag/dagnode.py @@ -45,6 +45,9 @@ class DAGNode: ---------- `name` : str The name of the node. + `meta_name` : str, optional, default=name + The meta name of the node. This is used to store additional information + and prevents identical nodes which can lead to issues with `__eq__`. `parents` : set[quick.circuit.dag.DAGNode], optional, default=set() A set of parent nodes. `children` : set[quick.circuit.dag.DAGNode], optional, default=set() @@ -55,9 +58,20 @@ class DAGNode: >>> node1 = DAGNode("Node 1") """ name: Hashable = None + meta_name: Hashable = None parents: set[DAGNode] = field(default_factory=set) children: set[DAGNode] = field(default_factory=set) + def __post_init__(self) -> None: + """ Initialize the node. + + Notes + ----- + This method is called after the node is initialized. We set the + `meta_name` attribute to the `name` attribute if it is not provided. + """ + self.meta_name = self.name if self.meta_name is None else self.meta_name + def _invalidate_depth(self) -> None: """ Invalidate the cached depth of the node. @@ -125,6 +139,26 @@ def to( if hasattr(self, "_depth"): self._invalidate_depth() + def walk(self): + """ Walk through the children nodes of the current node. + + Yields + ------ + `child` : quick.circuit.dag.DAGNode + The next child node. + + Usage + ----- + >>> node1 = DAGNode("Node 1") + >>> node2 = DAGNode("Node 2") + >>> node1.to(node2) + >>> for child in node1.walk(): + ... print(child) + """ + for child in sorted(self.children): + yield child + yield from child.walk() + @property def depth(self) -> int: """ Get the depth of the node. @@ -268,15 +302,51 @@ def __hash__(self) -> int: """ return hash(id(self)) + def __lt__( + self, + other_node: object + ) -> bool: + """ Check if this node is less than or equal to another node. + + Parameters + ---------- + `other_node` : object + The object to compare to. + + Returns + ------- + bool + True if this node is less than or equal to the other node, + False otherwise. + + Raises + ------ + TypeError + - If `other_node` is not an instance of `quick.circuit.dag.DAGNode`. + + Usage + ----- + >>> node1 = DAGNode("Node 1") + >>> node2 = DAGNode("Node 2") + >>> node1 < node2 + """ + if not isinstance(other_node, DAGNode): + raise TypeError( + "The `other_node` must be an instance of DAGNode. " + f"Received {type(other_node)} instead." + ) + + return str(self.meta_name) < str(other_node.meta_name) + def __eq__( self, - other: object + other_node: object ) -> bool: """ Check if two nodes are equal. Parameters ---------- - `other` : object + `other_node` : object The object to compare to. Returns @@ -284,16 +354,55 @@ def __eq__( bool True if the nodes are equal, False otherwise. + Raises + ------ + TypeError + - If `other_node` is not an instance of `quick.circuit.dag.DAGNode`. + Usage ----- >>> node1 = DAGNode("Node 1") >>> node2 = DAGNode("Node 2") >>> node1 == node2 """ - if not isinstance(other, DAGNode): + if not isinstance(other_node, DAGNode): + raise TypeError( + "The `other_node` must be an instance of DAGNode. " + f"Received {type(other_node)} instead." + ) + + if self.meta_name != other_node.meta_name: + return False + + inclusion_check: dict[DAGNode, int] = {} + self_stack: list[DAGNode] = [self] + other_stack: list[DAGNode] = [other_node] + + while self_stack and other_stack: + self_node = self_stack.pop() + other_node = other_stack.pop() + + if self_node in inclusion_check and other_node in inclusion_check: + continue + + # Set the inclusion check of the nodes to zero + # (this value is arbitrary, we just need a placeholder) + inclusion_check[self_node] = 0 + inclusion_check[other_node] = 0 + + for self_child, other_child in zip(sorted(self_node.children), sorted(other_node.children)): + self_stack.append(self_child) + other_stack.append(other_child) + + if self_child.meta_name != other_child.meta_name: + return False + + # If two DAGs are equal up until the end, but one has additional + # nodes afterwards it will be caught here + if self_stack or other_stack: return False - return self.name == other.name and self.children == other.children + return True def __repr__(self) -> str: """ Get the string representation of the node. diff --git a/stubs/quick/circuit/circuit.pyi b/stubs/quick/circuit/circuit.pyi index 26ff27f..c20d8ae 100644 --- a/stubs/quick/circuit/circuit.pyi +++ b/stubs/quick/circuit/circuit.pyi @@ -334,6 +334,12 @@ class Circuit(ABC, metaclass=abc.ABCMeta): def __getitem__(self, index: int | slice) -> Circuit: ... def __setitem__(self, index: int | slice, gates: dict | list[dict] | Circuit) -> None: ... def __eq__(self, other_circuit: object) -> bool: ... + def is_equivalent( + self, + other_circuit: Circuit, + check_unitary: bool=True, + check_dag: bool=False + ) -> bool: ... def __len__(self) -> int: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... diff --git a/stubs/quick/circuit/dag/dagnode.pyi b/stubs/quick/circuit/dag/dagnode.pyi index 16743e1..87029ab 100644 --- a/stubs/quick/circuit/dag/dagnode.pyi +++ b/stubs/quick/circuit/dag/dagnode.pyi @@ -12,20 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Hashable __all__ = ["DAGNode"] @dataclass class DAGNode: - name: Hashable = ... - parents: set[DAGNode] = ... - children: set[DAGNode] = ... + name: Hashable = None + meta_name: Hashable = None + parents: set[DAGNode] = field(default_factory=set) + children: set[DAGNode] = field(default_factory=set) + def __post_init__(self) -> None: ... def to(self, child: DAGNode) -> None: ... @property def depth(self) -> int: ... def generate_paths(self) -> set[tuple[Hashable]]: ... def __hash__(self) -> int: ... + def __lt__(self, other: object) -> bool: ... def __eq__(self, other: object) -> bool: ... def __init__(self, name=..., parents=..., children=...) -> None: ... diff --git a/tests/circuit/dag/test_dagcircuit.py b/tests/circuit/dag/test_dagcircuit.py index 72c2da8..8d847c6 100644 --- a/tests/circuit/dag/test_dagcircuit.py +++ b/tests/circuit/dag/test_dagcircuit.py @@ -38,6 +38,11 @@ def test_add_operation(self) -> None: assert repr(circuit) == "\n".join(["Q0: Q0 -> {H -> {CX}}", "Q1: Q1 -> {CX}"]) + circuit = DAGCircuit(1) + circuit.add_operation({"gate": "RX", "angle": 0.1, "qubit_indices": 0}) + + assert repr(circuit) == "Q0: Q0 -> {RX('angle': 0.1)}" + def test_get_depth(self) -> None: """ Test the `get_depth` method of a `DAGCircuit` object. """ diff --git a/tests/circuit/dag/test_dagnode.py b/tests/circuit/dag/test_dagnode.py index 0cbd2ed..58e4753 100644 --- a/tests/circuit/dag/test_dagnode.py +++ b/tests/circuit/dag/test_dagnode.py @@ -29,6 +29,7 @@ def test_init(self) -> None: """ dagnode = DAGNode("test_node") assert dagnode.name == "test_node" + assert dagnode.meta_name == "test_node" assert dagnode.children == set() assert dagnode.parents == set() @@ -96,6 +97,52 @@ def test_to_invalid(self) -> None: with pytest.raises(TypeError): dagnode1.to(dagnode2) # type: ignore + def test_lt(self) -> None: + """ Test the less than comparison of two `DAGNode` objects. + """ + dagnode1 = DAGNode("node1") + dagnode2 = DAGNode("node2") + + assert dagnode1 < dagnode2 + + def test_lt_invalid(self) -> None: + """ Test the less than comparison of a `DAGNode` object with an invalid argument. + """ + dagnode = DAGNode("node1") + invalid = "node2" + + with pytest.raises(TypeError): + dagnode < invalid # type: ignore + + def test_eq(self) -> None: + """ Test the equality of two `DAGNode` objects. + """ + dagnode1 = DAGNode("node1") + dagnode2 = DAGNode("node2") + dagnode3 = DAGNode("node1") + + assert dagnode1 == dagnode3 + assert dagnode1 != dagnode2 + + node_a = DAGNode("A") + node_b = DAGNode("B") + node_a2 = DAGNode("A") + node_b2 = DAGNode("B") + + node_a.to(node_b) + node_a2.to(node_b2) + + assert node_a == node_a2 + + def test_eq_invalid(self) -> None: + """ Test the equality of a `DAGNode` object with an invalid argument. + """ + dagnode = DAGNode("node1") + invalid = "node2" + + with pytest.raises(TypeError): + dagnode == invalid # type: ignore + def test_str(self) -> None: """ Test the string representation of a `DAGNode` object. """ diff --git a/tests/circuit/test_circuit_base.py b/tests/circuit/test_circuit_base.py index 3ee9882..c9ffad1 100644 --- a/tests/circuit/test_circuit_base.py +++ b/tests/circuit/test_circuit_base.py @@ -866,7 +866,7 @@ def test_change_mapping_indices_value_error( circuit.change_mapping([0, 1, 2]) @pytest.mark.parametrize("circuit_framework", CIRCUIT_FRAMEWORKS) - def test_from_circuit( + def test_convert( self, circuit_framework: Type[Circuit] ) -> None: @@ -1225,6 +1225,115 @@ def test_eq( for circuit_1, circuit_2 in zip(circuits[0:-1:], circuits[1::]): assert circuit_1 == circuit_2 + # Test the equality of circuits when the order + # of the gates are different but the circuit is the same + circuit_1 = [circuit_framework(2) for circuit_framework in circuit_frameworks] + for circuit in circuit_1: + circuit.H(0) + circuit.X(1) + + circuit_2 = [circuit_framework(2) for circuit_framework in circuit_frameworks] + for circuit in circuit_2: + circuit.X(1) + circuit.H(0) + + # Test the equality of the circuits + for circuit_1, circuit_2 in zip(circuit_1, circuit_2): + assert circuit_1 == circuit_2 + + @pytest.mark.parametrize("circuit_framework", CIRCUIT_FRAMEWORKS) + def test_eq_fail( + self, + circuit_framework: Type[Circuit] + ) -> None: + """ Test the `__eq__` dunder method failure. + + Parameters + ---------- + `circuit_framework`: type[quick.circuit.Circuit] + The circuit framework to test. + """ + circuit_1 = circuit_framework(2) + circuit_2 = circuit_framework(3) + + assert not circuit_1 == circuit_2 + + circuit_1 = circuit_framework(2) + circuit_2 = circuit_framework(2) + + circuit_1.H(0) + circuit_2.X(0) + + assert not circuit_1 == circuit_2 + + circuit_1 = circuit_framework(2) + circuit_2 = circuit_framework(2) + + circuit_1.H(0) + circuit_2.H(0) + + circuit_1.CX(0, 1) + circuit_2.CX(0, 1) + + circuit_1.H(0) + circuit_2.H(1) + + assert not circuit_1 == circuit_2 + + circuit_1 = circuit_framework(2) + circuit_2 = circuit_framework(2) + + circuit_1.CX(0, 1) + circuit_2.CX(1, 0) + + assert not circuit_1 == circuit_2 + + @pytest.mark.parametrize("circuit_framework", CIRCUIT_FRAMEWORKS) + def test_eq_invalid_type( + self, + circuit_framework: Type[Circuit] + ) -> None: + """ Test the `__eq__` dunder method failure with invalid type. + + Parameters + ---------- + `circuit_framework`: type[quick.circuit.Circuit] + The circuit framework to test. + """ + circuit = circuit_framework(2) + + with pytest.raises(TypeError): + circuit == "circuit" # type: ignore + + @pytest.mark.parametrize("circuit_framework", CIRCUIT_FRAMEWORKS) + def test_is_equivalent( + self, + circuit_framework: Type[Circuit] + ) -> None: + """ Test the `is_equivalent` method. + + Parameters + ---------- + `circuit_framework`: type[quick.circuit.Circuit] + The circuit framework to test. + """ + # Define the circuits + circuit_1 = circuit_framework(3) + circuit_2 = circuit_framework(3) + + # Define the GHZ state + circuit_1.H(0) + circuit_1.CX(0, 1) + circuit_1.CX(0, 2) + + circuit_2.H(0) + circuit_2.CX(0, 2) + circuit_2.CX(0, 1) + + # Test the equivalence of the circuits + assert circuit_1.is_equivalent(circuit_2) + assert not circuit_1.is_equivalent(circuit_2, check_dag=True) + @pytest.mark.parametrize("circuit_framework", CIRCUIT_FRAMEWORKS) def test_len( self,