diff --git a/pdm.lock b/pdm.lock index bba10e0e3..128043400 100644 --- a/pdm.lock +++ b/pdm.lock @@ -6,7 +6,7 @@ groups = ["default", "dev", "doc"] cross_platform = true static_urls = false lock_version = "4.3" -content_hash = "sha256:7c9349fd0b2c2d5448eb58bfba7701983bff94e26bf00510038cf0e1339c3cc2" +content_hash = "sha256:c2995ea995ac3393ea3e9e41bfe2ecbb41d84836e2f0c6aaa3dc81d93c727adc" [[package]] name = "amazon-braket-default-simulator" @@ -2227,6 +2227,20 @@ files = [ {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, ] +[[package]] +name = "plum-dispatch" +version = "2.2.2" +requires_python = ">=3.8" +summary = "Multiple dispatch in Python" +dependencies = [ + "beartype>=0.16.2", + "typing-extensions; python_version <= \"3.10\"", +] +files = [ + {file = "plum_dispatch-2.2.2-py3-none-any.whl", hash = "sha256:d7ee415bd166ffa90eaa4b24d7c9dc7ca6f8875750586001e7c9baff706223bd"}, + {file = "plum_dispatch-2.2.2.tar.gz", hash = "sha256:d5d180225c9fbf0277375bb558b649d97d0b651a91037bb7155cedbe9f52764b"}, +] + [[package]] name = "pre-commit" version = "3.4.0" diff --git a/pyproject.toml b/pyproject.toml index b9a7155ba..bd02ae646 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "plotext>=5.2.8", "beartype>=0.15.0", "simplejson>=3.19.1", + "plum-dispatch>=2.2.2", ] requires-python = ">=3.9" readme = "README.md" diff --git a/src/bloqade/ir/location/transform.py b/src/bloqade/ir/location/transform.py index 41e7a2e4f..62a7d0604 100644 --- a/src/bloqade/ir/location/transform.py +++ b/src/bloqade/ir/location/transform.py @@ -3,6 +3,7 @@ from beartype.vale import Is from typing import Annotated from beartype import beartype +from plum import dispatch import numpy as np from bloqade.ir.scalar import cast @@ -46,7 +47,7 @@ def scale(self, scale: ScalarType): return ListOfLocations(location_list) - @beartype + @dispatch def _add_position( self, position: Tuple[ScalarType, ScalarType], filling: Optional[bool] = None ): @@ -61,8 +62,8 @@ def _add_position( return ListOfLocations(location_list) - @beartype - def _add_position_list( + @dispatch + def _add_position( # noqa: F811 self, position: List[Tuple[ScalarType, ScalarType]], filling: Optional[List[bool]] = None, @@ -86,11 +87,11 @@ def _add_position_list( return ListOfLocations(location_list) - @beartype - def _add_numpy_position( + @dispatch + def _add_position( # noqa: F811 self, position: PositionArray, filling: Optional[BoolArray] = None ): - return self._add_position_list( + return self.add_position( list(map(tuple, position.tolist())), filling.tolist() if filling is not None else None, ) @@ -98,11 +99,11 @@ def _add_numpy_position( def add_position( self, position: Union[ - Tuple[ScalarType, ScalarType], - List[Tuple[ScalarType, ScalarType]], PositionArray, + List[Tuple[ScalarType, ScalarType]], + Tuple[ScalarType, ScalarType], ], - filling: Optional[Union[bool, list[bool], BoolArray]] = None, + filling: Optional[Union[BoolArray, List[bool], bool]] = None, ) -> "ListOfLocations": """add a position or list of positions to existing atom arrangement. @@ -113,25 +114,15 @@ def add_position( position to add filling (bool | list[bool] | numpy.array with shape (n, ) | None, optional): - filling of the added position(s). Defaults to None. + filling of the added position(s). Defaults to None. if None, all + positions are filled. + Returns: ListOfLocations: new atom arrangement with added positions """ - if isinstance(position, tuple) and isinstance(filling, (bool, type(None))): - return self._add_position(position, filling) - elif isinstance(position, list) and isinstance(filling, (list, type(None))): - return self._add_position_list(position, filling) - elif isinstance(position, np.ndarray) and isinstance( - filling, (np.ndarray, type(None)) - ): - return self._add_numpy_position(position, filling) - else: - raise TypeError( - f"cannot interpret arguments, got {type(position)} " - f"for position and {type(filling)} for filling" - ) + return self._add_position(position, filling) @beartype def apply_defect_count( diff --git a/tests/test_builder.py b/tests/test_builder.py index a402e5bc2..1d32aafea 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -13,6 +13,7 @@ # import bloqade.builder.backend as builder_backend import bloqade.ir.routine.quera as quera import bloqade.ir.routine.braket as braket +from plum import NotFoundLookupError from bloqade.ir.control.waveform import instruction from bloqade.ir import rydberg, detuning, hyperfine, rabi @@ -67,9 +68,9 @@ def test_add_position_dispatch(): position = np.array([[1, 2], [3, 4]]) position_list = list(map(tuple, position.tolist())) - a = start.add_position(position) - b = start.add_position(position_list) - c = start.add_position(position_list[0]).add_position(position_list[1]) + a = start.add_position(position, np.array([True, False])) + b = start.add_position(position_list, [True, False]) + c = start.add_position(position_list[0]).add_position(position_list[1], False) assert a.location_list == b.location_list assert a.location_list == c.location_list @@ -77,6 +78,15 @@ def test_add_position_dispatch(): with pytest.raises(AssertionError): start.add_position(position_list, [True]) + with pytest.raises(NotFoundLookupError): + start.add_position(position_list, True) + + with pytest.raises(NotFoundLookupError): + start.add_position(position_list, np.array([True, True])) + + with pytest.raises(NotFoundLookupError): + start.add_position(position, [True, True]) + def test_piecewise_const(): prog = start.rydberg.detuning.uniform.piecewise_constant(