Skip to content

Commit

Permalink
fixing bugs with new batch_assign method.
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Sep 27, 2023
1 parent 6df9f35 commit fadefd4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
23 changes: 19 additions & 4 deletions src/bloqade/builder/assign.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import repeat
from beartype.typing import Optional, List, Dict, Set
from beartype.typing import Optional, List, Dict, Set, Sequence, Union
from bloqade.builder.typing import ParamType
from bloqade.builder.base import Builder
from bloqade.builder.pragmas import Parallelizable, AddArgs, BatchAssignable
Expand All @@ -26,7 +26,7 @@ def cast_scalar_param(self, value: ParamType, name: str) -> Decimal:

def cast_vector_param(
self,
value: List[ParamType],
value: Union[np.ndarray, List[ParamType]],
name: str,
) -> List[Decimal]:
if isinstance(value, np.ndarray):
Expand All @@ -47,6 +47,7 @@ def cast_vector_param(

def cast_params(self, params: Dict[str, ParamType]) -> Dict[str, ParamType]:
checked_params = {}

for name, value in params.items():
if name not in self.scalar_vars and name not in self.vector_vars:
raise ValueError(
Expand Down Expand Up @@ -86,7 +87,7 @@ class BatchAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
__match_args__ = ("_assignments", "__parent__")

def __init__(
self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None
self, assignments: Dict[str, List[ParamType]], parent: Optional[Builder] = None
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

Expand Down Expand Up @@ -115,7 +116,9 @@ def __init__(

class ListAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
def __init__(
self, batch_params: List[Dict[str, ParamType]], parent: Optional[Builder] = None
self,
batch_params: Sequence[Dict[str, ParamType]],
parent: Optional[Builder] = None,
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

Expand All @@ -127,4 +130,16 @@ def __init__(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
)

keys = set([])
for params in batch_params:
keys.update(params.keys())

for batch_num, params in enumerate(batch_params):
curr_keys = set(params.keys())
missing_keys = keys.difference(curr_keys)
if missing_keys:
raise ValueError(

Check warning on line 141 in src/bloqade/builder/assign.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/builder/assign.py#L141

Added line #L141 was not covered by tests
f"Batch {batch_num} missing key(s): {tuple(missing_keys)}."
)

self._batch_params = list(map(caster.cast_params, batch_params))
3 changes: 2 additions & 1 deletion src/bloqade/ir/location/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ def __init__(
else:
self.location_list.append(LocationInfo(ele, True))

if location_list:
if self.location_list:
self.__n_atoms = sum(
1 for loc in self.location_list if loc.filling == SiteFilling.filled
)
self.__n_sites = len(self.location_list)
self.__n_vacant = self.__n_sites - self.__n_atoms
self.__n_dims = len(self.location_list[0].position)
else:
self.__n_sites = 0
self.__n_atoms = 0
self.__n_dims = None

Expand Down
12 changes: 11 additions & 1 deletion tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,21 @@ def test_assign_error():
with pytest.raises(TypeError):
start.rydberg.detuning.uniform.constant("c", "t").assign(c=np, t=10)

with pytest.raises(TypeError):
with pytest.raises(ValueError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(
c=[1, 2, np], t=[10]
)

with pytest.raises(TypeError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(
c=[1, 2, np], t=[10, 20, 30]
)

list_dict = [dict(c=1, t=10), dict(c=2, t=20), dict(c=np, t=30)]

with pytest.raises(TypeError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict)

with pytest.raises(TypeError):
(
start.add_position((0, 0))
Expand Down

0 comments on commit fadefd4

Please sign in to comment.