diff --git a/src/bloqade/builder/assign.py b/src/bloqade/builder/assign.py index 8e8e5abe5..af43e6945 100644 --- a/src/bloqade/builder/assign.py +++ b/src/bloqade/builder/assign.py @@ -1,5 +1,5 @@ -from itertools import repeat, starmap -from beartype.typing import Optional, List +from itertools import repeat +from beartype.typing import Optional, List, Dict, Set, Sequence, Union from bloqade.builder.typing import ParamType from bloqade.builder.base import Builder from bloqade.builder.pragmas import Parallelizable, AddArgs, BatchAssignable @@ -9,50 +9,56 @@ import numpy as np -def cast_scalar_param(value: ParamType, name: str) -> Decimal: - if isinstance(value, (Real, Decimal)): - return Decimal(str(value)) +class CastParams: + def __init__(self, n_sites: int, scalar_vars: Set[str], vector_vars: Set[str]): + self.n_sites = n_sites + self.scalar_vars = scalar_vars + self.vector_vars = vector_vars - raise TypeError( - f"assign parameter '{name}' must be a real number, " - f"found type: {type(value)}" - ) + def cast_scalar_param(self, value: ParamType, name: str) -> Decimal: + if isinstance(value, (Real, Decimal)): + return Decimal(str(value)) + raise TypeError( + f"assign parameter '{name}' must be a real number, " + f"found type: {type(value)}" + ) -def cast_vector_param(value: List[ParamType], name: str) -> List[Decimal]: - if isinstance(value, np.ndarray): - value = value.tolist() + def cast_vector_param( + self, + value: Union[np.ndarray, List[ParamType]], + name: str, + ) -> List[Decimal]: + if isinstance(value, np.ndarray): + value = value.tolist() - if isinstance(value, (list, tuple)): - return list(starmap(cast_scalar_param, zip(value, repeat(name)))) - - raise TypeError( - f"assign parameter '{name}' must be a list of real numbers, " - f"found type: {type(value)}" - ) - - -def cast_batch_scalar_param(value: List[ParamType], name: str) -> List[Decimal]: - if isinstance(value, np.ndarray): - value = value.tolist() - - if isinstance(value, (list, tuple)): - return list(starmap(cast_scalar_param, zip(value, repeat(name)))) + if isinstance(value, (list, tuple)): + if len(value) != self.n_sites: + raise ValueError( + f"assign parameter '{name}' must be a list of length " + f"{self.n_sites}, found length: {len(value)}" + ) + return list(map(self.cast_scalar_param, value, repeat(name, len(value)))) - raise TypeError( - f"batch_assign parameter '{name}' must be a list of real numbers, " - f"found type: {type(value)}" - ) + raise TypeError( + f"assign parameter '{name}' must be a list of real numbers, " + f"found type: {type(value)}" + ) + def cast_params(self, params: Dict[str, ParamType]) -> Dict[str, ParamType]: + checked_params = {} -def cast_batch_vector_param(value: List[ParamType], name: str) -> List[List[Decimal]]: - if isinstance(value, (list, tuple)): - return list(starmap(cast_vector_param, zip(value, repeat(name)))) + for name, value in params.items(): + if name not in self.scalar_vars and name not in self.vector_vars: + raise ValueError( + f"assign parameter '{name}' is not found in analog circuit." + ) + if name in self.vector_vars: + checked_params[name] = self.cast_vector_param(value, name) + else: + checked_params[name] = self.cast_scalar_param(value, name) - raise TypeError( - f"batch_assign parameter '{name}' must be a list of lists of real numbers, " - f"found type: {type(value)}" - ) + return checked_params class AssignBase(Builder): @@ -62,49 +68,78 @@ class AssignBase(Builder): class Assign(BatchAssignable, AddArgs, Parallelizable, BackendRoute, AssignBase): __match_args__ = ("_assignments", "__parent__") - def __init__(self, parent: Optional[Builder] = None, **assignments) -> None: + def __init__( + self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None + ) -> None: from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit super().__init__(parent) circuit = self.parse_circuit() - vars = ScanVariablesAnalogCircuit().emit(circuit) + variables = ScanVariablesAnalogCircuit().emit(circuit) - self._assignments = {} - for name, value in assignments.items(): - if name not in vars.scalar_vars and name not in vars.vector_vars: - raise ValueError( - f"batch_assign parameter '{name}' is not found in analog circuit." - ) - if name in vars.vector_vars: - self._assignments[name] = cast_vector_param(value, name) - else: - self._assignments[name] = cast_scalar_param(value, name) + self._static_params = CastParams( + circuit.register.n_sites, variables.scalar_vars, variables.vector_vars + ).cast_params(assignments) class BatchAssign(AddArgs, Parallelizable, BackendRoute, AssignBase): __match_args__ = ("_assignments", "__parent__") - def __init__(self, parent: Optional[Builder] = None, **assignments) -> None: + def __init__( + self, assignments: Dict[str, List[ParamType]], parent: Optional[Builder] = None + ) -> None: from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit super().__init__(parent) circuit = self.parse_circuit() - vars = ScanVariablesAnalogCircuit().emit(circuit) - - self._assignments = {} - for name, values in assignments.items(): - if name not in vars.scalar_vars and name not in vars.vector_vars: - raise ValueError( - f"batch_assign parameter '{name}' is not found in analog circuit." - ) - if name in vars.vector_vars: - self._assignments[name] = cast_batch_vector_param(values, name) - else: - self._assignments[name] = cast_batch_scalar_param(values, name) + variables = ScanVariablesAnalogCircuit().emit(circuit) if not len(np.unique(list(map(len, assignments.values())))) == 1: raise ValueError( "all the assignment variables need to have same number of elements." ) + + tuple_iterators = [ + zip(repeat(name), values) for name, values in assignments.items() + ] + + caster = CastParams( + circuit.register.n_sites, variables.scalar_vars, variables.vector_vars + ) + + self._batch_params = list( + map(caster.cast_params, map(dict, zip(*tuple_iterators))) + ) + + +class ListAssign(AddArgs, Parallelizable, BackendRoute, AssignBase): + def __init__( + self, + batch_params: Sequence[Dict[str, ParamType]], + parent: Optional[Builder] = None, + ) -> None: + from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit + + super().__init__(parent) + + circuit = self.parse_circuit() + variables = ScanVariablesAnalogCircuit().emit(circuit) + caster = CastParams( + circuit.register.n_sites, variables.scalar_vars, variables.vector_vars + ) + + keys = set([]) + for params in batch_params: + keys.update(params.keys()) + + for batch_num, params in enumerate(batch_params): + curr_keys = set(params.keys()) + missing_keys = keys.difference(curr_keys) + if missing_keys: + raise ValueError( + f"Batch {batch_num} missing key(s): {tuple(missing_keys)}." + ) + + self._batch_params = list(map(caster.cast_params, batch_params)) diff --git a/src/bloqade/builder/parse/builder.py b/src/bloqade/builder/parse/builder.py index 7b324ca03..71eff69cb 100644 --- a/src/bloqade/builder/parse/builder.py +++ b/src/bloqade/builder/parse/builder.py @@ -4,14 +4,12 @@ from bloqade.builder.field import Field, Detuning, RabiAmplitude, RabiPhase from bloqade.builder.spatial import SpatialModulation, Location, Uniform, Var, Scale from bloqade.builder.waveform import WaveformPrimitive, Slice, Record, Sample, Fn -from bloqade.builder.assign import Assign, BatchAssign +from bloqade.builder.assign import Assign, BatchAssign, ListAssign from bloqade.builder.args import Args from bloqade.builder.parallelize import Parallelize from bloqade.builder.parse.stream import BuilderNode, BuilderStream - import bloqade.ir as ir -from itertools import repeat -from typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set +from beartype.typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set if TYPE_CHECKING: from bloqade.ir.routine.params import ParamType @@ -186,6 +184,7 @@ def read_pragmas(self) -> None: pragma_types = ( Assign, BatchAssign, + ListAssign, Args, Parallelize, ) @@ -197,13 +196,9 @@ def read_pragmas(self) -> None: node = curr.node if isinstance(node, Assign): - self.static_params = dict(node._assignments) - elif isinstance(node, BatchAssign): - tuple_iterators = [ - zip(repeat(name), values) - for name, values in node._assignments.items() - ] - self.batch_params = list(map(dict, zip(*tuple_iterators))) + self.static_params = dict(node._static_params) + elif isinstance(node, BatchAssign) or isinstance(node, ListAssign): + self.batch_params = node._batch_params elif isinstance(node, Args): order = node._order diff --git a/src/bloqade/builder/pragmas.py b/src/bloqade/builder/pragmas.py index d3f9e547a..cf3139517 100644 --- a/src/bloqade/builder/pragmas.py +++ b/src/bloqade/builder/pragmas.py @@ -1,9 +1,9 @@ -from beartype.typing import List, Union, TYPE_CHECKING +from beartype.typing import List, Dict, Union, TYPE_CHECKING from bloqade.builder.typing import LiteralType, ParamType from bloqade.ir.scalar import Variable if TYPE_CHECKING: - from bloqade.builder.assign import Assign, BatchAssign + from bloqade.builder.assign import Assign, BatchAssign, ListAssign from bloqade.builder.parallelize import Parallelize from bloqade.builder.args import Args @@ -39,14 +39,24 @@ def assign(self, **assignments) -> "Assign": """ from bloqade.builder.assign import Assign - return Assign(parent=self, **assignments) + return Assign(assignments, parent=self) class BatchAssignable: - def batch_assign(self, **assignments: ParamType) -> "BatchAssign": - from bloqade.builder.assign import BatchAssign - - return BatchAssign(parent=self, **assignments) + def batch_assign( + self, + __batch_params: List[Dict[str, ParamType]] = [], + **assignments: List[ParamType], + ) -> Union["BatchAssign", "ListAssign"]: + from bloqade.builder.assign import BatchAssign, ListAssign + + if len(__batch_params) > 0 and assignments: + raise ValueError("batch_params and assignments cannot be used together.") + + if len(__batch_params) > 0: + return ListAssign(__batch_params, parent=self) + else: + return BatchAssign(assignments, parent=self) class Parallelizable: diff --git a/src/bloqade/ir/location/list.py b/src/bloqade/ir/location/list.py index 405ab6c52..313544ef8 100644 --- a/src/bloqade/ir/location/list.py +++ b/src/bloqade/ir/location/list.py @@ -19,7 +19,7 @@ def __init__( else: self.location_list.append(LocationInfo(ele, True)) - if location_list: + if self.location_list: self.__n_atoms = sum( 1 for loc in self.location_list if loc.filling == SiteFilling.filled ) @@ -27,6 +27,7 @@ def __init__( self.__n_vacant = self.__n_sites - self.__n_atoms self.__n_dims = len(self.location_list[0].position) else: + self.__n_sites = 0 self.__n_atoms = 0 self.__n_dims = None diff --git a/src/bloqade/ir/scalar.py b/src/bloqade/ir/scalar.py index 784c9ccb7..f2ce3c2e7 100644 --- a/src/bloqade/ir/scalar.py +++ b/src/bloqade/ir/scalar.py @@ -404,9 +404,14 @@ def _repr_pretty_(self, p, cycle): Printer(p).print(self, cycle) @validator("name", allow_reuse=True) - def validate_name(cls, v): - check_variable_name(v) - return v + def validate_name(cls, name): + check_variable_name(name) + if name in ["__batch_params"]: + raise ValidationError( + "Cannot use reserved name `__batch_params` for variable name" + ) + + return name @dataclass(frozen=True, repr=False) @@ -430,9 +435,14 @@ def print_node(self): return f"AssignedVariable: {self.name} = {self.value}" @validator("name", allow_reuse=True) - def validate_name(cls, v): - check_variable_name(v) - return v + def validate_name(cls, name): + check_variable_name(name) + if name in ["__batch_params"]: + raise ValidationError( + "Cannot use reserved name `__batch_params` for variable name" + ) + + return name @dataclass(frozen=True, repr=False) diff --git a/tests/test_builder.py b/tests/test_builder.py index 451d75e07..a402e5bc2 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -306,11 +306,61 @@ def test_assign_error(): with pytest.raises(TypeError): start.rydberg.detuning.uniform.constant("c", "t").assign(c=np, t=10) - with pytest.raises(TypeError): + with pytest.raises(ValueError): start.rydberg.detuning.uniform.constant("c", "t").batch_assign( c=[1, 2, np], t=[10] ) + with pytest.raises(TypeError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign( + c=[1, 2, np], t=[10, 20, 30] + ) + + list_dict = [dict(c=1, t=10), dict(c=2, t=20), dict(c=np, t=30)] + with pytest.raises(TypeError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict) + + list_dict = [dict(c=1, t=10), dict(c=2, t=20), dict(t=30)] + with pytest.raises(ValueError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict) + + list_dict = [dict(c=1, t=10, f=1), dict(c=2, t=20, f=2), dict(t=30, c=3, f=3)] + with pytest.raises(ValueError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict) + + dict_list = dict( + c=[1, 2, 3], t=[10, 20, 30], mask=[[1, 2, 3], [4, 5, 6], [7, 8, 9]] + ) + + with pytest.raises(ValueError): + start.rydberg.detuning.uniform.constant("c", "t").batch_assign(**dict_list) + + list_dict = [ + dict(c=1, t=10, mask=[1, 2, 3]), + dict(c=2, t=20, mask=[4, 5, 6]), + dict(t=30, c=3, mask=[7, 8, 9]), + ] + + with pytest.raises(ValueError): + start.add_position([(0, 0), (0, 6)]).rydberg.detuning.var("mask").constant( + "c", "t" + ).batch_assign(list_dict) + + # happy path is to have a list of dicts with the same keys + start.add_position([(0, 0), (0, 6), (0, 12)]).rydberg.detuning.var("mask").constant( + "c", "t" + ).batch_assign(list_dict) + + list_dict = [ + dict(c=1, t=10, mask=np.array([1, 2, 3])), + dict(c=2, t=20, mask=np.array([4, 5, 6])), + dict(t=30, c=3, mask=np.array([7, 8, 9])), + ] + with pytest.raises(ValueError): + start.add_position([(0, 0), (0, 6)]).rydberg.detuning.var("mask").constant( + "c", "t" + ).batch_assign(list_dict) + with pytest.raises(TypeError): ( start.add_position((0, 0)) diff --git a/tests/test_scalar.py b/tests/test_scalar.py index dd434ae19..30bed1bf0 100644 --- a/tests/test_scalar.py +++ b/tests/test_scalar.py @@ -1,3 +1,4 @@ +from pydantic import ValidationError from bloqade import cast, var import bloqade.ir.scalar as scalar import pytest @@ -19,6 +20,9 @@ def test_var(): with pytest.raises(ValueError): var("a*b") + with pytest.raises(ValidationError): + var("__batch_params") + Vv = var("a") vs = var(Vv)