From 0f690f3dee1422c34c3590b8fb7f0af9b911a252 Mon Sep 17 00:00:00 2001 From: Julien Gacon Date: Thu, 7 Nov 2024 13:34:10 +0100 Subject: [PATCH] remove Iterable instancechecks --- qiskit/circuit/library/n_local/n_local.py | 56 ++++++++++++---------- test/python/circuit/library/test_nlocal.py | 5 ++ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/qiskit/circuit/library/n_local/n_local.py b/qiskit/circuit/library/n_local/n_local.py index b8f8d317669..8c0b4d28508 100644 --- a/qiskit/circuit/library/n_local/n_local.py +++ b/qiskit/circuit/library/n_local/n_local.py @@ -1313,36 +1313,38 @@ def _normalize_entanglement( if isinstance(entanglement, str): return [entanglement] * num_entanglement_blocks - elif not callable(entanglement): - # handle edge cases when entanglement is set to an empty list - if len(entanglement) == 0: - return [[]] - - if isinstance(entanglement[0], str): - # if the entanglement is given as iterable we must make sure it matches - # the number of entanglement blocks - if len(entanglement) != num_entanglement_blocks: - raise QiskitError( - f"Number of block-entanglements ({len(entanglement)}) must match number of " - f"entanglement blocks ({num_entanglement_blocks})!" - ) + if callable(entanglement): + return lambda offset: _normalize_entanglement(entanglement(offset), num_entanglement_blocks) - return entanglement + # here, entanglement is an Iterable + if len(entanglement) == 0: + # handle edge cases when entanglement is set to an empty list + return [[]] - # normalize to list[BlockEntanglement] - if not isinstance(entanglement[0][0], Iterable): + # if the entanglement is Iterable[Iterable[int]], normalize to Iterable[Iterable[Iterable[int]]] + try: + # if users e.g. gave Iterable[int] this in invalid and will raise a TypeError + if isinstance(entanglement[0][0], (int, numpy.integer)): entanglement = [entanglement] + except TypeError as exc: + raise TypeError(f"Invalid entanglement type: {entanglement}.") from exc + + # ensure the number of block entanglements matches the number of blocks + if len(entanglement) != num_entanglement_blocks: + raise QiskitError( + f"Number of block-entanglements ({len(entanglement)}) must match number of " + f"entanglement blocks ({num_entanglement_blocks})!" + ) - if len(entanglement) != num_entanglement_blocks: - raise QiskitError( - f"Number of block-entanglements ({len(entanglement)}) must match number of " - f"entanglement blocks ({num_entanglement_blocks})!" - ) - - return [[tuple(connections) for connections in block] for block in entanglement] + # normalize the data: str remains, and Iterable[Iterable[int]] becomes list[tuple[int]] + normalized = [] + for block in entanglement: + if isinstance(block, str): + normalized.append(block) + else: + normalized.append([tuple(connections) for connections in block]) - # here, entanglement is a callable - return lambda offset: _normalize_entanglement(entanglement(offset), num_entanglement_blocks) + return normalized def _normalize_blocks( @@ -1350,7 +1352,9 @@ def _normalize_blocks( supported_gates: dict[str, Gate], overwrite_block_parameters: bool, ) -> list[Block]: - if not isinstance(blocks, Iterable) or isinstance(blocks, str): + # normalize the input into an iterable -- we add an extra check for a circuit as + # courtesy to the users, since the NLocal class used to accept circuits + if isinstance(blocks, (str, Gate, QuantumCircuit)): blocks = [blocks] normalized = [] diff --git a/test/python/circuit/library/test_nlocal.py b/test/python/circuit/library/test_nlocal.py index b4ab4fb085a..753feff490b 100644 --- a/test/python/circuit/library/test_nlocal.py +++ b/test/python/circuit/library/test_nlocal.py @@ -674,6 +674,11 @@ def test_entanglement_list_of_str(self): self.assertEqual(2, circuit.count_ops().get("cx", 0)) self.assertEqual(3, circuit.count_ops().get("cz", 0)) + def test_invalid_entanglement_list(self): + """Test passing an invalid list.""" + with self.assertRaises(TypeError): + _ = n_local(3, "h", "cx", entanglement=[0, 1]) # should be [(0, 1)] + def test_mismatching_entanglement_blocks_str(self): """Test an error is raised if the number of entanglements does not match the blocks.""" entanglement = ["full", "linear", "pairwise"]