Skip to content

Commit 6df9f35

Browse files
committed
debugging new casting.
1 parent 96f1b08 commit 6df9f35

File tree

3 files changed

+107
-82
lines changed

3 files changed

+107
-82
lines changed

src/bloqade/builder/assign.py

Lines changed: 84 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from itertools import repeat, starmap
2-
from beartype.typing import Optional, List
1+
from itertools import repeat
2+
from beartype.typing import Optional, List, Dict, Set
33
from bloqade.builder.typing import ParamType
44
from bloqade.builder.base import Builder
55
from bloqade.builder.pragmas import Parallelizable, AddArgs, BatchAssignable
@@ -9,50 +9,55 @@
99
import numpy as np
1010

1111

12-
def cast_scalar_param(value: ParamType, name: str) -> Decimal:
13-
if isinstance(value, (Real, Decimal)):
14-
return Decimal(str(value))
12+
class CastParams:
13+
def __init__(self, n_sites: int, scalar_vars: Set[str], vector_vars: Set[str]):
14+
self.n_sites = n_sites
15+
self.scalar_vars = scalar_vars
16+
self.vector_vars = vector_vars
1517

16-
raise TypeError(
17-
f"assign parameter '{name}' must be a real number, "
18-
f"found type: {type(value)}"
19-
)
18+
def cast_scalar_param(self, value: ParamType, name: str) -> Decimal:
19+
if isinstance(value, (Real, Decimal)):
20+
return Decimal(str(value))
2021

22+
raise TypeError(
23+
f"assign parameter '{name}' must be a real number, "
24+
f"found type: {type(value)}"
25+
)
2126

22-
def cast_vector_param(value: List[ParamType], name: str) -> List[Decimal]:
23-
if isinstance(value, np.ndarray):
24-
value = value.tolist()
27+
def cast_vector_param(
28+
self,
29+
value: List[ParamType],
30+
name: str,
31+
) -> List[Decimal]:
32+
if isinstance(value, np.ndarray):
33+
value = value.tolist()
2534

26-
if isinstance(value, (list, tuple)):
27-
return list(starmap(cast_scalar_param, zip(value, repeat(name))))
28-
29-
raise TypeError(
30-
f"assign parameter '{name}' must be a list of real numbers, "
31-
f"found type: {type(value)}"
32-
)
33-
34-
35-
def cast_batch_scalar_param(value: List[ParamType], name: str) -> List[Decimal]:
36-
if isinstance(value, np.ndarray):
37-
value = value.tolist()
38-
39-
if isinstance(value, (list, tuple)):
40-
return list(starmap(cast_scalar_param, zip(value, repeat(name))))
41-
42-
raise TypeError(
43-
f"batch_assign parameter '{name}' must be a list of real numbers, "
44-
f"found type: {type(value)}"
45-
)
35+
if isinstance(value, (list, tuple)):
36+
if len(value) != self.n_sites:
37+
raise ValueError(
38+
f"assign parameter '{name}' must be a list of length "
39+
f"{self.n_sites}, found length: {len(value)}"
40+
)
41+
return list(map(self.cast_scalar_param, value, repeat(name, len(value))))
4642

43+
raise TypeError(
44+
f"assign parameter '{name}' must be a list of real numbers, "
45+
f"found type: {type(value)}"
46+
)
4747

48-
def cast_batch_vector_param(value: List[ParamType], name: str) -> List[List[Decimal]]:
49-
if isinstance(value, (list, tuple)):
50-
return list(starmap(cast_vector_param, zip(value, repeat(name))))
48+
def cast_params(self, params: Dict[str, ParamType]) -> Dict[str, ParamType]:
49+
checked_params = {}
50+
for name, value in params.items():
51+
if name not in self.scalar_vars and name not in self.vector_vars:
52+
raise ValueError(
53+
f"assign parameter '{name}' is not found in analog circuit."
54+
)
55+
if name in self.vector_vars:
56+
checked_params[name] = self.cast_vector_param(value, name)
57+
else:
58+
checked_params[name] = self.cast_scalar_param(value, name)
5159

52-
raise TypeError(
53-
f"batch_assign parameter '{name}' must be a list of lists of real numbers, "
54-
f"found type: {type(value)}"
55-
)
60+
return checked_params
5661

5762

5863
class AssignBase(Builder):
@@ -62,49 +67,64 @@ class AssignBase(Builder):
6267
class Assign(BatchAssignable, AddArgs, Parallelizable, BackendRoute, AssignBase):
6368
__match_args__ = ("_assignments", "__parent__")
6469

65-
def __init__(self, parent: Optional[Builder] = None, **assignments) -> None:
70+
def __init__(
71+
self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None
72+
) -> None:
6673
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit
6774

6875
super().__init__(parent)
6976

7077
circuit = self.parse_circuit()
71-
vars = ScanVariablesAnalogCircuit().emit(circuit)
78+
variables = ScanVariablesAnalogCircuit().emit(circuit)
7279

73-
self._assignments = {}
74-
for name, value in assignments.items():
75-
if name not in vars.scalar_vars and name not in vars.vector_vars:
76-
raise ValueError(
77-
f"batch_assign parameter '{name}' is not found in analog circuit."
78-
)
79-
if name in vars.vector_vars:
80-
self._assignments[name] = cast_vector_param(value, name)
81-
else:
82-
self._assignments[name] = cast_scalar_param(value, name)
80+
self._static_params = CastParams(
81+
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
82+
).cast_params(assignments)
8383

8484

8585
class BatchAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
8686
__match_args__ = ("_assignments", "__parent__")
8787

88-
def __init__(self, parent: Optional[Builder] = None, **assignments) -> None:
88+
def __init__(
89+
self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None
90+
) -> None:
8991
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit
9092

9193
super().__init__(parent)
9294

9395
circuit = self.parse_circuit()
94-
vars = ScanVariablesAnalogCircuit().emit(circuit)
95-
96-
self._assignments = {}
97-
for name, values in assignments.items():
98-
if name not in vars.scalar_vars and name not in vars.vector_vars:
99-
raise ValueError(
100-
f"batch_assign parameter '{name}' is not found in analog circuit."
101-
)
102-
if name in vars.vector_vars:
103-
self._assignments[name] = cast_batch_vector_param(values, name)
104-
else:
105-
self._assignments[name] = cast_batch_scalar_param(values, name)
96+
variables = ScanVariablesAnalogCircuit().emit(circuit)
10697

10798
if not len(np.unique(list(map(len, assignments.values())))) == 1:
10899
raise ValueError(
109100
"all the assignment variables need to have same number of elements."
110101
)
102+
103+
tuple_iterators = [
104+
zip(repeat(name), values) for name, values in assignments.items()
105+
]
106+
107+
caster = CastParams(
108+
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
109+
)
110+
111+
self._batch_params = list(
112+
map(caster.cast_params, map(dict, zip(*tuple_iterators)))
113+
)
114+
115+
116+
class ListAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
117+
def __init__(
118+
self, batch_params: List[Dict[str, ParamType]], parent: Optional[Builder] = None
119+
) -> None:
120+
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit
121+
122+
super().__init__(parent)
123+
124+
circuit = self.parse_circuit()
125+
variables = ScanVariablesAnalogCircuit().emit(circuit)
126+
caster = CastParams(
127+
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
128+
)
129+
130+
self._batch_params = list(map(caster.cast_params, batch_params))

src/bloqade/builder/parse/builder.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44
from bloqade.builder.field import Field, Detuning, RabiAmplitude, RabiPhase
55
from bloqade.builder.spatial import SpatialModulation, Location, Uniform, Var, Scale
66
from bloqade.builder.waveform import WaveformPrimitive, Slice, Record, Sample, Fn
7-
from bloqade.builder.assign import Assign, BatchAssign
7+
from bloqade.builder.assign import Assign, BatchAssign, ListAssign
88
from bloqade.builder.args import Args
99
from bloqade.builder.parallelize import Parallelize
1010
from bloqade.builder.parse.stream import BuilderNode, BuilderStream
11-
1211
import bloqade.ir as ir
13-
from itertools import repeat
14-
from typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set
12+
from beartype.typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set
1513

1614
if TYPE_CHECKING:
1715
from bloqade.ir.routine.params import ParamType
@@ -186,6 +184,7 @@ def read_pragmas(self) -> None:
186184
pragma_types = (
187185
Assign,
188186
BatchAssign,
187+
ListAssign,
189188
Args,
190189
Parallelize,
191190
)
@@ -197,13 +196,9 @@ def read_pragmas(self) -> None:
197196
node = curr.node
198197

199198
if isinstance(node, Assign):
200-
self.static_params = dict(node._assignments)
201-
elif isinstance(node, BatchAssign):
202-
tuple_iterators = [
203-
zip(repeat(name), values)
204-
for name, values in node._assignments.items()
205-
]
206-
self.batch_params = list(map(dict, zip(*tuple_iterators)))
199+
self.static_params = dict(node._static_params)
200+
elif isinstance(node, BatchAssign) or isinstance(node, ListAssign):
201+
self.batch_params = node._batch_params
207202
elif isinstance(node, Args):
208203
order = node._order
209204

src/bloqade/builder/pragmas.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from beartype.typing import List, Union, TYPE_CHECKING
1+
from beartype.typing import List, Dict, Union, TYPE_CHECKING
22
from bloqade.builder.typing import LiteralType, ParamType
33
from bloqade.ir.scalar import Variable
44

55
if TYPE_CHECKING:
6-
from bloqade.builder.assign import Assign, BatchAssign
6+
from bloqade.builder.assign import Assign, BatchAssign, ListAssign
77
from bloqade.builder.parallelize import Parallelize
88
from bloqade.builder.args import Args
99

@@ -39,14 +39,24 @@ def assign(self, **assignments) -> "Assign":
3939
"""
4040
from bloqade.builder.assign import Assign
4141

42-
return Assign(parent=self, **assignments)
42+
return Assign(assignments, parent=self)
4343

4444

4545
class BatchAssignable:
46-
def batch_assign(self, **assignments: ParamType) -> "BatchAssign":
47-
from bloqade.builder.assign import BatchAssign
48-
49-
return BatchAssign(parent=self, **assignments)
46+
def batch_assign(
47+
self,
48+
batch_params: List[Dict[str, ParamType]] = [],
49+
**assignments: List[ParamType],
50+
) -> Union["BatchAssign", "ListAssign"]:
51+
from bloqade.builder.assign import BatchAssign, ListAssign
52+
53+
if len(batch_params) > 0 and assignments:
54+
raise ValueError("batch_params and assignments cannot be used together.")
55+
56+
if len(batch_params) > 0:
57+
return ListAssign(batch_params, parent=self)
58+
else:
59+
return BatchAssign(assignments, parent=self)
5060

5161

5262
class Parallelizable:

0 commit comments

Comments
 (0)