From fadefd4f085432d0f117129c805c0fdb7900f3e1 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Wed, 27 Sep 2023 11:04:04 -0400 Subject: [PATCH] fixing bugs with new `batch_assign` method. --- src/bloqade/builder/assign.py | 23 +++++++++++++++++++---- src/bloqade/ir/location/list.py | 3 ++- tests/test_builder.py | 12 +++++++++++- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/bloqade/builder/assign.py b/src/bloqade/builder/assign.py index 92741abfb..af43e6945 100644 --- a/src/bloqade/builder/assign.py +++ b/src/bloqade/builder/assign.py @@ -1,5 +1,5 @@ from itertools import repeat -from beartype.typing import Optional, List, Dict, Set +from beartype.typing import Optional, List, Dict, Set, Sequence, Union from bloqade.builder.typing import ParamType from bloqade.builder.base import Builder from bloqade.builder.pragmas import Parallelizable, AddArgs, BatchAssignable @@ -26,7 +26,7 @@ def cast_scalar_param(self, value: ParamType, name: str) -> Decimal: def cast_vector_param( self, - value: List[ParamType], + value: Union[np.ndarray, List[ParamType]], name: str, ) -> List[Decimal]: if isinstance(value, np.ndarray): @@ -47,6 +47,7 @@ def cast_vector_param( def cast_params(self, params: Dict[str, ParamType]) -> Dict[str, ParamType]: checked_params = {} + for name, value in params.items(): if name not in self.scalar_vars and name not in self.vector_vars: raise ValueError( @@ -86,7 +87,7 @@ class BatchAssign(AddArgs, Parallelizable, BackendRoute, AssignBase): __match_args__ = ("_assignments", "__parent__") def __init__( - self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None + self, assignments: Dict[str, List[ParamType]], parent: Optional[Builder] = None ) -> None: from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit @@ -115,7 +116,9 @@ def __init__( class ListAssign(AddArgs, Parallelizable, BackendRoute, AssignBase): def __init__( - self, batch_params: List[Dict[str, ParamType]], parent: Optional[Builder] = None + self, + batch_params: Sequence[Dict[str, ParamType]], + parent: Optional[Builder] = None, ) -> None: from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit @@ -127,4 +130,16 @@ def __init__( circuit.register.n_sites, variables.scalar_vars, variables.vector_vars ) + keys = set([]) + for params in batch_params: + keys.update(params.keys()) + + for batch_num, params in enumerate(batch_params): + curr_keys = set(params.keys()) + missing_keys = keys.difference(curr_keys) + if missing_keys: + raise ValueError( + f"Batch {batch_num} missing key(s): {tuple(missing_keys)}." + ) + self._batch_params = list(map(caster.cast_params, batch_params)) diff --git a/src/bloqade/ir/location/list.py b/src/bloqade/ir/location/list.py index 405ab6c52..313544ef8 100644 --- a/src/bloqade/ir/location/list.py +++ b/src/bloqade/ir/location/list.py @@ -19,7 +19,7 @@ def __init__( else: self.location_list.append(LocationInfo(ele, True)) - if location_list: + if self.location_list: self.__n_atoms = sum( 1 for loc in self.location_list if loc.filling == SiteFilling.filled ) @@ -27,6 +27,7 @@ def __init__( self.__n_vacant = self.__n_sites - self.__n_atoms self.__n_dims = len(self.location_list[0].position) else: + self.__n_sites = 0 self.__n_atoms = 0 self.__n_dims = None diff --git a/tests/test_builder.py b/tests/test_builder.py index 451d75e07..66d51b827 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -306,11 +306,21 @@ def test_assign_error(): with pytest.raises(TypeError): start.rydberg.detuning.uniform.constant("c", "t").assign(c=np, t=10) - with pytest.raises(TypeError): + with pytest.raises(ValueError): start.rydberg.detuning.uniform.constant("c", "t").batch_assign( c=[1, 2, np], t=[10] ) + with pytest.raises(TypeError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign( + c=[1, 2, np], t=[10, 20, 30] + ) + + list_dict = [dict(c=1, t=10), dict(c=2, t=20), dict(c=np, t=30)] + + with pytest.raises(TypeError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict) + with pytest.raises(TypeError): ( start.add_position((0, 0))