Skip to content

Commit

Permalink
remove Iterable instancechecks
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryoris committed Nov 7, 2024
1 parent 6950476 commit 0f690f3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
56 changes: 30 additions & 26 deletions qiskit/circuit/library/n_local/n_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,44 +1313,48 @@ def _normalize_entanglement(
if isinstance(entanglement, str):
return [entanglement] * num_entanglement_blocks

elif not callable(entanglement):
# handle edge cases when entanglement is set to an empty list
if len(entanglement) == 0:
return [[]]

if isinstance(entanglement[0], str):
# if the entanglement is given as iterable we must make sure it matches
# the number of entanglement blocks
if len(entanglement) != num_entanglement_blocks:
raise QiskitError(
f"Number of block-entanglements ({len(entanglement)}) must match number of "
f"entanglement blocks ({num_entanglement_blocks})!"
)
if callable(entanglement):
return lambda offset: _normalize_entanglement(entanglement(offset), num_entanglement_blocks)

return entanglement
# here, entanglement is an Iterable
if len(entanglement) == 0:
# handle edge cases when entanglement is set to an empty list
return [[]]

# normalize to list[BlockEntanglement]
if not isinstance(entanglement[0][0], Iterable):
# if the entanglement is Iterable[Iterable[int]], normalize to Iterable[Iterable[Iterable[int]]]
try:
# if users e.g. gave Iterable[int] this in invalid and will raise a TypeError
if isinstance(entanglement[0][0], (int, numpy.integer)):
entanglement = [entanglement]
except TypeError as exc:
raise TypeError(f"Invalid entanglement type: {entanglement}.") from exc

# ensure the number of block entanglements matches the number of blocks
if len(entanglement) != num_entanglement_blocks:
raise QiskitError(
f"Number of block-entanglements ({len(entanglement)}) must match number of "
f"entanglement blocks ({num_entanglement_blocks})!"
)

if len(entanglement) != num_entanglement_blocks:
raise QiskitError(
f"Number of block-entanglements ({len(entanglement)}) must match number of "
f"entanglement blocks ({num_entanglement_blocks})!"
)

return [[tuple(connections) for connections in block] for block in entanglement]
# normalize the data: str remains, and Iterable[Iterable[int]] becomes list[tuple[int]]
normalized = []
for block in entanglement:
if isinstance(block, str):
normalized.append(block)
else:
normalized.append([tuple(connections) for connections in block])

# here, entanglement is a callable
return lambda offset: _normalize_entanglement(entanglement(offset), num_entanglement_blocks)
return normalized


def _normalize_blocks(
blocks: str | Gate | Iterable[str | Gate],
supported_gates: dict[str, Gate],
overwrite_block_parameters: bool,
) -> list[Block]:
if not isinstance(blocks, Iterable) or isinstance(blocks, str):
# normalize the input into an iterable -- we add an extra check for a circuit as
# courtesy to the users, since the NLocal class used to accept circuits
if isinstance(blocks, (str, Gate, QuantumCircuit)):
blocks = [blocks]

normalized = []
Expand Down
5 changes: 5 additions & 0 deletions test/python/circuit/library/test_nlocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ def test_entanglement_list_of_str(self):
self.assertEqual(2, circuit.count_ops().get("cx", 0))
self.assertEqual(3, circuit.count_ops().get("cz", 0))

def test_invalid_entanglement_list(self):
"""Test passing an invalid list."""
with self.assertRaises(TypeError):
_ = n_local(3, "h", "cx", entanglement=[0, 1]) # should be [(0, 1)]

def test_mismatching_entanglement_blocks_str(self):
"""Test an error is raised if the number of entanglements does not match the blocks."""
entanglement = ["full", "linear", "pairwise"]
Expand Down

0 comments on commit 0f690f3

Please sign in to comment.