Skip to content

Commit 4c72a23

Browse files
authored
Fixing more bugs in emulator. (#614)
* making detuning matrix elements negative. * mising factor of 2 in rabi amplitude. * add integration test comparing braket to bloqade emulator. * removing extra sorting in KS test impl. * adding check for bounds in locations. * adding visitor to check if sequence has hyperfine coupling. * fixing some bugs in local drives. * making tests shorter, adding scipy KS test * reduce shots counts * reducing test time for emulator. * adding integration tests for emulator ir codegen (#623) * update `LocalBatch` annotation.
1 parent 88be4f1 commit 4c72a23

File tree

7 files changed

+413
-36
lines changed

7 files changed

+413
-36
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Any
2+
import bloqade.ir.analog_circuit as analog_circuit
3+
import bloqade.ir.control.sequence as sequence
4+
from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor
5+
6+
7+
class IsHyperfineSequence(AnalogCircuitVisitor):
8+
def __init__(self):
9+
self.is_hyperfine = False
10+
11+
def visit_analog_circuit(self, ast: analog_circuit.AnalogCircuit) -> Any:
12+
self.visit(ast.sequence)
13+
14+
def visit_append_sequence(self, ast: sequence.Append) -> Any:
15+
list(map(self.visit, ast.sequences))
16+
17+
def visit_slice_sequence(self, ast: sequence.Slice) -> Any:
18+
self.visit(ast.sequence)
19+
20+
def visit_named_sequence(self, ast: sequence.NamedSequence) -> Any:
21+
self.visit(ast.sequence)
22+
23+
def visit_sequence(self, ast: sequence.Sequence) -> Any:
24+
self.is_hyperfine = self.is_hyperfine or sequence.hyperfine in ast.pulses
25+
26+
def emit(self, ast: analog_circuit.AnalogCircuit) -> bool:
27+
self.visit(ast)
28+
return self.is_hyperfine

src/bloqade/codegen/emulator_ir.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import bloqade.ir.control.sequence as sequence
1313
import bloqade.ir.control.pulse as pulse
1414
import bloqade.ir.control.waveform as waveform
15-
import bloqade.ir.scalar as scalar
15+
import bloqade.ir.control.field as field
1616
import bloqade.ir as ir
17-
17+
from bloqade.codegen.common.is_hyperfine import IsHyperfineSequence
1818
from bloqade.emulate.ir.atom_type import ThreeLevelAtom, TwoLevelAtom
1919
from bloqade.emulate.ir.emulator import (
2020
DetuningOperatorData,
@@ -59,20 +59,18 @@ def __init__(
5959
self.original_index = []
6060

6161
def visit_analog_circuit(self, ast: ir.AnalogCircuit):
62-
self.n_atoms = ast.register.n_atoms
63-
64-
self.visit(ast.sequence)
6562
self.visit(ast.register)
63+
self.visit(ast.sequence)
6664

6765
def visit_register(self, ast: AtomArrangement) -> Any:
6866
positions = []
69-
for original_index, loc_info in enumerate(ast.enumerate()):
67+
for org_index, loc_info in enumerate(ast.enumerate()):
7068
if loc_info.filling == SiteFilling.filled:
7169
position = tuple([pos(**self.assignments) for pos in loc_info.position])
7270
positions.append(position)
73-
self.original_index.append(original_index)
71+
self.original_index.append(org_index)
7472

75-
if sequence.hyperfine in self.level_couplings:
73+
if self.is_hyperfine:
7674
self.register = Register(
7775
ThreeLevelAtom,
7876
positions,
@@ -91,7 +89,6 @@ def visit_sequence(self, ast: sequence.Sequence) -> None:
9189
sequence.rydberg: LevelCoupling.Rydberg,
9290
}
9391
for level_coupling, sub_pulse in ast.pulses.items():
94-
self.level_couplings.add(level_coupling)
9592
self.visit(sub_pulse)
9693
self.pulses[level_coupling_mapping[level_coupling]] = Fields(
9794
detuning=self.detuning_terms,
@@ -144,9 +141,18 @@ def visit_assigned_run_time_vector(
144141

145142
def visit_scaled_locations(self, ast: ScaledLocations) -> Dict[int, Decimal]:
146143
target_atoms = {}
144+
145+
for location in ast.value.keys():
146+
if location.value >= self.n_sites or location.value < 0:
147+
raise ValueError(
148+
f"Location {location.value} is out of bounds for register with "
149+
f"{self.n_sites} sites."
150+
)
151+
147152
for new_index, original_index in enumerate(self.original_index):
148-
value = ast.value.get(original_index, scalar.Literal(0))
149-
target_atoms[new_index] = value(**self.assignments)
153+
value = ast.value.get(field.Location(original_index))
154+
if value is not None:
155+
target_atoms[new_index] = value(**self.assignments)
150156

151157
return target_atoms
152158

@@ -174,8 +180,9 @@ def visit_detuning(self, ast: Optional[Field]):
174180
for atom in range(self.n_atoms):
175181
wf = sum(
176182
(
177-
target_atom_dict[sm].get(atom, 0.0) * wf
183+
target_atom_dict[sm][atom] * wf
178184
for sm, wf in ast.value.items()
185+
if atom in target_atom_dict[sm]
179186
),
180187
start=waveform.Constant(0.0, 0.0),
181188
)
@@ -221,8 +228,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
221228
for atom in range(self.n_atoms):
222229
amplitude_wf = sum(
223230
(
224-
amplitude_target_atoms_dict[sm].get(atom, 0.0) * wf
231+
amplitude_target_atoms_dict[sm][atom] * wf
225232
for sm, wf in amplitude.value.items()
233+
if atom in amplitude_target_atoms_dict[sm]
226234
),
227235
start=waveform.Constant(0.0, 0.0),
228236
)
@@ -274,16 +282,18 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
274282
for atom in range(self.n_atoms):
275283
phase_wf = sum(
276284
(
277-
phase_target_atoms_dict[sm].get(atom, 0.0) * wf
285+
phase_target_atoms_dict[sm][atom] * wf
278286
for sm, wf in phase.value.items()
287+
if atom in phase_target_atoms_dict[sm]
279288
),
280289
start=waveform.Constant(0.0, 0.0),
281290
)
282291

283292
amplitude_wf = sum(
284293
(
285-
amplitude_target_atoms_dict[sm].get(atom, 0.0) * wf
294+
amplitude_target_atoms_dict[sm][atom] * wf
286295
for sm, wf in amplitude.value.items()
296+
if atom in amplitude_target_atoms_dict[sm]
287297
),
288298
start=waveform.Constant(0.0, 0.0),
289299
)
@@ -310,6 +320,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
310320

311321
def emit(self, circuit: ir.AnalogCircuit) -> EmulatorProgram:
312322
self.assignments = AssignmentScan(self.assignments).emit(circuit.sequence)
323+
self.is_hyperfine = IsHyperfineSequence().emit(circuit)
324+
self.n_atoms = circuit.register.n_atoms
325+
self.n_sites = circuit.register.n_sites
313326

314327
self.visit(circuit)
315328
return EmulatorProgram(

src/bloqade/emulate/codegen/hamiltonian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def visit_fields(self, fields: Fields):
9898

9999
def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
100100
if (self.register, detuning_data) in self.compile_cache.operator_cache:
101-
return self.compile_cache.operator_cache[(self.space, detuning_data)]
101+
return self.compile_cache.operator_cache[(self.register, detuning_data)]
102102

103103
diagonal = np.zeros(self.space.size, dtype=np.float64)
104104
if self.space.atom_type == TwoLevelAtomType():
@@ -110,7 +110,7 @@ def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
110110
state = ThreeLevelAtomType.State.Hyperfine
111111

112112
for atom_index, value in detuning_data.target_atoms.items():
113-
diagonal[self.space.is_state_at(atom_index, state)] += float(value)
113+
diagonal[self.space.is_state_at(atom_index, state)] -= float(value)
114114

115115
self.compile_cache.operator_cache[(self.register, detuning_data)] = diagonal
116116
return diagonal

src/bloqade/emulate/ir/state_vector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class RabiOperator:
3434
phase: Optional[Callable[[float], float]] = None
3535

3636
def dot(self, register: NDArray, time: float):
37-
amplitude = self.amplitude(time)
37+
amplitude = self.amplitude(time) / 2
3838
if self.phase is None:
3939
return self.op.dot(register) * amplitude
4040

src/bloqade/task/batch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from bloqade.task.quera import QuEraTask
44
from bloqade.task.braket import BraketTask
55
from bloqade.task.braket_simulator import BraketEmulatorTask
6+
from bloqade.task.bloqade import BloqadeTask
67

78
from bloqade.builder.base import Builder
89

@@ -45,7 +46,7 @@ def json(self, **options) -> str:
4546
@Serializer.register
4647
class LocalBatch(Serializable):
4748
source: Optional[Builder]
48-
tasks: OrderedDict[int, BraketEmulatorTask]
49+
tasks: OrderedDict[int, Union[BraketEmulatorTask, BloqadeTask]]
4950
name: Optional[str] = None
5051

5152
def report(self) -> Report:

0 commit comments

Comments
 (0)