12
12
import bloqade .ir .control .sequence as sequence
13
13
import bloqade .ir .control .pulse as pulse
14
14
import bloqade .ir .control .waveform as waveform
15
- import bloqade .ir .scalar as scalar
15
+ import bloqade .ir .control . field as field
16
16
import bloqade .ir as ir
17
-
17
+ from bloqade . codegen . common . is_hyperfine import IsHyperfineSequence
18
18
from bloqade .emulate .ir .atom_type import ThreeLevelAtom , TwoLevelAtom
19
19
from bloqade .emulate .ir .emulator import (
20
20
DetuningOperatorData ,
@@ -59,20 +59,18 @@ def __init__(
59
59
self .original_index = []
60
60
61
61
def visit_analog_circuit (self , ast : ir .AnalogCircuit ):
62
- self .n_atoms = ast .register .n_atoms
63
-
64
- self .visit (ast .sequence )
65
62
self .visit (ast .register )
63
+ self .visit (ast .sequence )
66
64
67
65
def visit_register (self , ast : AtomArrangement ) -> Any :
68
66
positions = []
69
- for original_index , loc_info in enumerate (ast .enumerate ()):
67
+ for org_index , loc_info in enumerate (ast .enumerate ()):
70
68
if loc_info .filling == SiteFilling .filled :
71
69
position = tuple ([pos (** self .assignments ) for pos in loc_info .position ])
72
70
positions .append (position )
73
- self .original_index .append (original_index )
71
+ self .original_index .append (org_index )
74
72
75
- if sequence . hyperfine in self .level_couplings :
73
+ if self .is_hyperfine :
76
74
self .register = Register (
77
75
ThreeLevelAtom ,
78
76
positions ,
@@ -91,7 +89,6 @@ def visit_sequence(self, ast: sequence.Sequence) -> None:
91
89
sequence .rydberg : LevelCoupling .Rydberg ,
92
90
}
93
91
for level_coupling , sub_pulse in ast .pulses .items ():
94
- self .level_couplings .add (level_coupling )
95
92
self .visit (sub_pulse )
96
93
self .pulses [level_coupling_mapping [level_coupling ]] = Fields (
97
94
detuning = self .detuning_terms ,
@@ -144,9 +141,18 @@ def visit_assigned_run_time_vector(
144
141
145
142
def visit_scaled_locations (self , ast : ScaledLocations ) -> Dict [int , Decimal ]:
146
143
target_atoms = {}
144
+
145
+ for location in ast .value .keys ():
146
+ if location .value >= self .n_sites or location .value < 0 :
147
+ raise ValueError (
148
+ f"Location { location .value } is out of bounds for register with "
149
+ f"{ self .n_sites } sites."
150
+ )
151
+
147
152
for new_index , original_index in enumerate (self .original_index ):
148
- value = ast .value .get (original_index , scalar .Literal (0 ))
149
- target_atoms [new_index ] = value (** self .assignments )
153
+ value = ast .value .get (field .Location (original_index ))
154
+ if value is not None :
155
+ target_atoms [new_index ] = value (** self .assignments )
150
156
151
157
return target_atoms
152
158
@@ -174,8 +180,9 @@ def visit_detuning(self, ast: Optional[Field]):
174
180
for atom in range (self .n_atoms ):
175
181
wf = sum (
176
182
(
177
- target_atom_dict [sm ]. get ( atom , 0.0 ) * wf
183
+ target_atom_dict [sm ][ atom ] * wf
178
184
for sm , wf in ast .value .items ()
185
+ if atom in target_atom_dict [sm ]
179
186
),
180
187
start = waveform .Constant (0.0 , 0.0 ),
181
188
)
@@ -221,8 +228,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
221
228
for atom in range (self .n_atoms ):
222
229
amplitude_wf = sum (
223
230
(
224
- amplitude_target_atoms_dict [sm ]. get ( atom , 0.0 ) * wf
231
+ amplitude_target_atoms_dict [sm ][ atom ] * wf
225
232
for sm , wf in amplitude .value .items ()
233
+ if atom in amplitude_target_atoms_dict [sm ]
226
234
),
227
235
start = waveform .Constant (0.0 , 0.0 ),
228
236
)
@@ -274,16 +282,18 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
274
282
for atom in range (self .n_atoms ):
275
283
phase_wf = sum (
276
284
(
277
- phase_target_atoms_dict [sm ]. get ( atom , 0.0 ) * wf
285
+ phase_target_atoms_dict [sm ][ atom ] * wf
278
286
for sm , wf in phase .value .items ()
287
+ if atom in phase_target_atoms_dict [sm ]
279
288
),
280
289
start = waveform .Constant (0.0 , 0.0 ),
281
290
)
282
291
283
292
amplitude_wf = sum (
284
293
(
285
- amplitude_target_atoms_dict [sm ]. get ( atom , 0.0 ) * wf
294
+ amplitude_target_atoms_dict [sm ][ atom ] * wf
286
295
for sm , wf in amplitude .value .items ()
296
+ if atom in amplitude_target_atoms_dict [sm ]
287
297
),
288
298
start = waveform .Constant (0.0 , 0.0 ),
289
299
)
@@ -310,6 +320,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
310
320
311
321
def emit (self , circuit : ir .AnalogCircuit ) -> EmulatorProgram :
312
322
self .assignments = AssignmentScan (self .assignments ).emit (circuit .sequence )
323
+ self .is_hyperfine = IsHyperfineSequence ().emit (circuit )
324
+ self .n_atoms = circuit .register .n_atoms
325
+ self .n_sites = circuit .register .n_sites
313
326
314
327
self .visit (circuit )
315
328
return EmulatorProgram (
0 commit comments