1
- from itertools import repeat , starmap
2
- from beartype .typing import Optional , List
1
+ from itertools import repeat
2
+ from beartype .typing import Optional , List , Dict , Set
3
3
from bloqade .builder .typing import ParamType
4
4
from bloqade .builder .base import Builder
5
5
from bloqade .builder .pragmas import Parallelizable , AddArgs , BatchAssignable
9
9
import numpy as np
10
10
11
11
12
- def cast_scalar_param (value : ParamType , name : str ) -> Decimal :
13
- if isinstance (value , (Real , Decimal )):
14
- return Decimal (str (value ))
12
+ class CastParams :
13
+ def __init__ (self , n_sites : int , scalar_vars : Set [str ], vector_vars : Set [str ]):
14
+ self .n_sites = n_sites
15
+ self .scalar_vars = scalar_vars
16
+ self .vector_vars = vector_vars
15
17
16
- raise TypeError (
17
- f"assign parameter '{ name } ' must be a real number, "
18
- f"found type: { type (value )} "
19
- )
18
+ def cast_scalar_param (self , value : ParamType , name : str ) -> Decimal :
19
+ if isinstance (value , (Real , Decimal )):
20
+ return Decimal (str (value ))
20
21
22
+ raise TypeError (
23
+ f"assign parameter '{ name } ' must be a real number, "
24
+ f"found type: { type (value )} "
25
+ )
21
26
22
- def cast_vector_param (value : List [ParamType ], name : str ) -> List [Decimal ]:
23
- if isinstance (value , np .ndarray ):
24
- value = value .tolist ()
27
+ def cast_vector_param (
28
+ self ,
29
+ value : List [ParamType ],
30
+ name : str ,
31
+ ) -> List [Decimal ]:
32
+ if isinstance (value , np .ndarray ):
33
+ value = value .tolist ()
25
34
26
- if isinstance (value , (list , tuple )):
27
- return list (starmap (cast_scalar_param , zip (value , repeat (name ))))
28
-
29
- raise TypeError (
30
- f"assign parameter '{ name } ' must be a list of real numbers, "
31
- f"found type: { type (value )} "
32
- )
33
-
34
-
35
- def cast_batch_scalar_param (value : List [ParamType ], name : str ) -> List [Decimal ]:
36
- if isinstance (value , np .ndarray ):
37
- value = value .tolist ()
38
-
39
- if isinstance (value , (list , tuple )):
40
- return list (starmap (cast_scalar_param , zip (value , repeat (name ))))
41
-
42
- raise TypeError (
43
- f"batch_assign parameter '{ name } ' must be a list of real numbers, "
44
- f"found type: { type (value )} "
45
- )
35
+ if isinstance (value , (list , tuple )):
36
+ if len (value ) != self .n_sites :
37
+ raise ValueError (
38
+ f"assign parameter '{ name } ' must be a list of length "
39
+ f"{ self .n_sites } , found length: { len (value )} "
40
+ )
41
+ return list (map (self .cast_scalar_param , value , repeat (name , len (value ))))
46
42
43
+ raise TypeError (
44
+ f"assign parameter '{ name } ' must be a list of real numbers, "
45
+ f"found type: { type (value )} "
46
+ )
47
47
48
- def cast_batch_vector_param (value : List [ParamType ], name : str ) -> List [List [Decimal ]]:
49
- if isinstance (value , (list , tuple )):
50
- return list (starmap (cast_vector_param , zip (value , repeat (name ))))
48
+ def cast_params (self , params : Dict [str , ParamType ]) -> Dict [str , ParamType ]:
49
+ checked_params = {}
50
+ for name , value in params .items ():
51
+ if name not in self .scalar_vars and name not in self .vector_vars :
52
+ raise ValueError (
53
+ f"assign parameter '{ name } ' is not found in analog circuit."
54
+ )
55
+ if name in self .vector_vars :
56
+ checked_params [name ] = self .cast_vector_param (value , name )
57
+ else :
58
+ checked_params [name ] = self .cast_scalar_param (value , name )
51
59
52
- raise TypeError (
53
- f"batch_assign parameter '{ name } ' must be a list of lists of real numbers, "
54
- f"found type: { type (value )} "
55
- )
60
+ return checked_params
56
61
57
62
58
63
class AssignBase (Builder ):
@@ -62,49 +67,64 @@ class AssignBase(Builder):
62
67
class Assign (BatchAssignable , AddArgs , Parallelizable , BackendRoute , AssignBase ):
63
68
__match_args__ = ("_assignments" , "__parent__" )
64
69
65
- def __init__ (self , parent : Optional [Builder ] = None , ** assignments ) -> None :
70
+ def __init__ (
71
+ self , assignments : Dict [str , ParamType ], parent : Optional [Builder ] = None
72
+ ) -> None :
66
73
from bloqade .ir .analysis .scan_variables import ScanVariablesAnalogCircuit
67
74
68
75
super ().__init__ (parent )
69
76
70
77
circuit = self .parse_circuit ()
71
- vars = ScanVariablesAnalogCircuit ().emit (circuit )
78
+ variables = ScanVariablesAnalogCircuit ().emit (circuit )
72
79
73
- self ._assignments = {}
74
- for name , value in assignments .items ():
75
- if name not in vars .scalar_vars and name not in vars .vector_vars :
76
- raise ValueError (
77
- f"batch_assign parameter '{ name } ' is not found in analog circuit."
78
- )
79
- if name in vars .vector_vars :
80
- self ._assignments [name ] = cast_vector_param (value , name )
81
- else :
82
- self ._assignments [name ] = cast_scalar_param (value , name )
80
+ self ._static_params = CastParams (
81
+ circuit .register .n_sites , variables .scalar_vars , variables .vector_vars
82
+ ).cast_params (assignments )
83
83
84
84
85
85
class BatchAssign (AddArgs , Parallelizable , BackendRoute , AssignBase ):
86
86
__match_args__ = ("_assignments" , "__parent__" )
87
87
88
- def __init__ (self , parent : Optional [Builder ] = None , ** assignments ) -> None :
88
+ def __init__ (
89
+ self , assignments : Dict [str , ParamType ], parent : Optional [Builder ] = None
90
+ ) -> None :
89
91
from bloqade .ir .analysis .scan_variables import ScanVariablesAnalogCircuit
90
92
91
93
super ().__init__ (parent )
92
94
93
95
circuit = self .parse_circuit ()
94
- vars = ScanVariablesAnalogCircuit ().emit (circuit )
95
-
96
- self ._assignments = {}
97
- for name , values in assignments .items ():
98
- if name not in vars .scalar_vars and name not in vars .vector_vars :
99
- raise ValueError (
100
- f"batch_assign parameter '{ name } ' is not found in analog circuit."
101
- )
102
- if name in vars .vector_vars :
103
- self ._assignments [name ] = cast_batch_vector_param (values , name )
104
- else :
105
- self ._assignments [name ] = cast_batch_scalar_param (values , name )
96
+ variables = ScanVariablesAnalogCircuit ().emit (circuit )
106
97
107
98
if not len (np .unique (list (map (len , assignments .values ())))) == 1 :
108
99
raise ValueError (
109
100
"all the assignment variables need to have same number of elements."
110
101
)
102
+
103
+ tuple_iterators = [
104
+ zip (repeat (name ), values ) for name , values in assignments .items ()
105
+ ]
106
+
107
+ caster = CastParams (
108
+ circuit .register .n_sites , variables .scalar_vars , variables .vector_vars
109
+ )
110
+
111
+ self ._batch_params = list (
112
+ map (caster .cast_params , map (dict , zip (* tuple_iterators )))
113
+ )
114
+
115
+
116
+ class ListAssign (AddArgs , Parallelizable , BackendRoute , AssignBase ):
117
+ def __init__ (
118
+ self , batch_params : List [Dict [str , ParamType ]], parent : Optional [Builder ] = None
119
+ ) -> None :
120
+ from bloqade .ir .analysis .scan_variables import ScanVariablesAnalogCircuit
121
+
122
+ super ().__init__ (parent )
123
+
124
+ circuit = self .parse_circuit ()
125
+ variables = ScanVariablesAnalogCircuit ().emit (circuit )
126
+ caster = CastParams (
127
+ circuit .register .n_sites , variables .scalar_vars , variables .vector_vars
128
+ )
129
+
130
+ self ._batch_params = list (map (caster .cast_params , batch_params ))
0 commit comments