Skip to content

Commit

Permalink
Merge branch 'main' into matrix-generation-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 authored Sep 27, 2023
2 parents 281485a + 1e02f30 commit b9ec056
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 27 deletions.
16 changes: 15 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"plotext>=5.2.8",
"beartype>=0.15.0",
"simplejson>=3.19.1",
"plum-dispatch>=2.2.2",
]
requires-python = ">=3.9"
readme = "README.md"
Expand Down
37 changes: 14 additions & 23 deletions src/bloqade/ir/location/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from beartype.vale import Is
from typing import Annotated
from beartype import beartype
from plum import dispatch
import numpy as np
from bloqade.ir.scalar import cast

Expand Down Expand Up @@ -46,7 +47,7 @@ def scale(self, scale: ScalarType):

return ListOfLocations(location_list)

@beartype
@dispatch
def _add_position(
self, position: Tuple[ScalarType, ScalarType], filling: Optional[bool] = None
):
Expand All @@ -61,8 +62,8 @@ def _add_position(

return ListOfLocations(location_list)

@beartype
def _add_position_list(
@dispatch
def _add_position( # noqa: F811
self,
position: List[Tuple[ScalarType, ScalarType]],
filling: Optional[List[bool]] = None,
Expand All @@ -86,23 +87,23 @@ def _add_position_list(

return ListOfLocations(location_list)

@beartype
def _add_numpy_position(
@dispatch
def _add_position( # noqa: F811
self, position: PositionArray, filling: Optional[BoolArray] = None
):
return self._add_position_list(
return self.add_position(
list(map(tuple, position.tolist())),
filling.tolist() if filling is not None else None,
)

def add_position(
self,
position: Union[
Tuple[ScalarType, ScalarType],
List[Tuple[ScalarType, ScalarType]],
PositionArray,
List[Tuple[ScalarType, ScalarType]],
Tuple[ScalarType, ScalarType],
],
filling: Optional[Union[bool, list[bool], BoolArray]] = None,
filling: Optional[Union[BoolArray, List[bool], bool]] = None,
) -> "ListOfLocations":
"""add a position or list of positions to existing atom arrangement.
Expand All @@ -113,25 +114,15 @@ def add_position(
position to add
filling (bool | list[bool]
| numpy.array with shape (n, ) | None, optional):
filling of the added position(s). Defaults to None.
filling of the added position(s). Defaults to None. if None, all
positions are filled.
Returns:
ListOfLocations: new atom arrangement with added positions
"""
if isinstance(position, tuple) and isinstance(filling, (bool, type(None))):
return self._add_position(position, filling)
elif isinstance(position, list) and isinstance(filling, (list, type(None))):
return self._add_position_list(position, filling)
elif isinstance(position, np.ndarray) and isinstance(
filling, (np.ndarray, type(None))
):
return self._add_numpy_position(position, filling)
else:
raise TypeError(
f"cannot interpret arguments, got {type(position)} "
f"for position and {type(filling)} for filling"
)
return self._add_position(position, filling)

@beartype
def apply_defect_count(
Expand Down
16 changes: 13 additions & 3 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# import bloqade.builder.backend as builder_backend
import bloqade.ir.routine.quera as quera
import bloqade.ir.routine.braket as braket
from plum import NotFoundLookupError

from bloqade.ir.control.waveform import instruction
from bloqade.ir import rydberg, detuning, hyperfine, rabi
Expand Down Expand Up @@ -67,16 +68,25 @@ def test_add_position_dispatch():
position = np.array([[1, 2], [3, 4]])
position_list = list(map(tuple, position.tolist()))

a = start.add_position(position)
b = start.add_position(position_list)
c = start.add_position(position_list[0]).add_position(position_list[1])
a = start.add_position(position, np.array([True, False]))
b = start.add_position(position_list, [True, False])
c = start.add_position(position_list[0]).add_position(position_list[1], False)

assert a.location_list == b.location_list
assert a.location_list == c.location_list

with pytest.raises(AssertionError):
start.add_position(position_list, [True])

with pytest.raises(NotFoundLookupError):
start.add_position(position_list, True)

with pytest.raises(NotFoundLookupError):
start.add_position(position_list, np.array([True, True]))

with pytest.raises(NotFoundLookupError):
start.add_position(position, [True, True])


def test_piecewise_const():
prog = start.rydberg.detuning.uniform.piecewise_constant(
Expand Down

0 comments on commit b9ec056

Please sign in to comment.