Skip to content

Commit

Permalink
Fixing more bugs in emulator. (#614)
Browse files Browse the repository at this point in the history
* making detuning matrix elements negative.

* mising factor of 2 in rabi amplitude.

* add integration test comparing braket to bloqade emulator.

* removing extra sorting in KS test impl.

* adding check for bounds in locations.

* adding visitor to check if sequence has hyperfine coupling.

* fixing some bugs in local drives.

* making tests shorter, adding scipy KS test

* reduce shots counts

* reducing test time for emulator.

* adding integration tests for emulator ir codegen (#623)

* update `LocalBatch` annotation.
  • Loading branch information
weinbe58 authored Sep 25, 2023
1 parent 88be4f1 commit 4c72a23
Show file tree
Hide file tree
Showing 7 changed files with 413 additions and 36 deletions.
28 changes: 28 additions & 0 deletions src/bloqade/codegen/common/is_hyperfine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any
import bloqade.ir.analog_circuit as analog_circuit
import bloqade.ir.control.sequence as sequence
from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor


class IsHyperfineSequence(AnalogCircuitVisitor):
def __init__(self):
self.is_hyperfine = False

def visit_analog_circuit(self, ast: analog_circuit.AnalogCircuit) -> Any:
self.visit(ast.sequence)

def visit_append_sequence(self, ast: sequence.Append) -> Any:
list(map(self.visit, ast.sequences))

def visit_slice_sequence(self, ast: sequence.Slice) -> Any:
self.visit(ast.sequence)

def visit_named_sequence(self, ast: sequence.NamedSequence) -> Any:
self.visit(ast.sequence)

def visit_sequence(self, ast: sequence.Sequence) -> Any:
self.is_hyperfine = self.is_hyperfine or sequence.hyperfine in ast.pulses

def emit(self, ast: analog_circuit.AnalogCircuit) -> bool:
self.visit(ast)
return self.is_hyperfine
43 changes: 28 additions & 15 deletions src/bloqade/codegen/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import bloqade.ir.control.sequence as sequence
import bloqade.ir.control.pulse as pulse
import bloqade.ir.control.waveform as waveform
import bloqade.ir.scalar as scalar
import bloqade.ir.control.field as field
import bloqade.ir as ir

from bloqade.codegen.common.is_hyperfine import IsHyperfineSequence
from bloqade.emulate.ir.atom_type import ThreeLevelAtom, TwoLevelAtom
from bloqade.emulate.ir.emulator import (
DetuningOperatorData,
Expand Down Expand Up @@ -59,20 +59,18 @@ def __init__(
self.original_index = []

def visit_analog_circuit(self, ast: ir.AnalogCircuit):
self.n_atoms = ast.register.n_atoms

self.visit(ast.sequence)
self.visit(ast.register)
self.visit(ast.sequence)

def visit_register(self, ast: AtomArrangement) -> Any:
positions = []
for original_index, loc_info in enumerate(ast.enumerate()):
for org_index, loc_info in enumerate(ast.enumerate()):
if loc_info.filling == SiteFilling.filled:
position = tuple([pos(**self.assignments) for pos in loc_info.position])
positions.append(position)
self.original_index.append(original_index)
self.original_index.append(org_index)

if sequence.hyperfine in self.level_couplings:
if self.is_hyperfine:
self.register = Register(
ThreeLevelAtom,
positions,
Expand All @@ -91,7 +89,6 @@ def visit_sequence(self, ast: sequence.Sequence) -> None:
sequence.rydberg: LevelCoupling.Rydberg,
}
for level_coupling, sub_pulse in ast.pulses.items():
self.level_couplings.add(level_coupling)
self.visit(sub_pulse)
self.pulses[level_coupling_mapping[level_coupling]] = Fields(
detuning=self.detuning_terms,
Expand Down Expand Up @@ -144,9 +141,18 @@ def visit_assigned_run_time_vector(

def visit_scaled_locations(self, ast: ScaledLocations) -> Dict[int, Decimal]:
target_atoms = {}

for location in ast.value.keys():
if location.value >= self.n_sites or location.value < 0:
raise ValueError(
f"Location {location.value} is out of bounds for register with "
f"{self.n_sites} sites."
)

for new_index, original_index in enumerate(self.original_index):
value = ast.value.get(original_index, scalar.Literal(0))
target_atoms[new_index] = value(**self.assignments)
value = ast.value.get(field.Location(original_index))
if value is not None:
target_atoms[new_index] = value(**self.assignments)

return target_atoms

Expand Down Expand Up @@ -174,8 +180,9 @@ def visit_detuning(self, ast: Optional[Field]):
for atom in range(self.n_atoms):
wf = sum(
(
target_atom_dict[sm].get(atom, 0.0) * wf
target_atom_dict[sm][atom] * wf
for sm, wf in ast.value.items()
if atom in target_atom_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)
Expand Down Expand Up @@ -221,8 +228,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
for atom in range(self.n_atoms):
amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm].get(atom, 0.0) * wf
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)
Expand Down Expand Up @@ -274,16 +282,18 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
for atom in range(self.n_atoms):
phase_wf = sum(
(
phase_target_atoms_dict[sm].get(atom, 0.0) * wf
phase_target_atoms_dict[sm][atom] * wf
for sm, wf in phase.value.items()
if atom in phase_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)

amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm].get(atom, 0.0) * wf
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)
Expand All @@ -310,6 +320,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):

def emit(self, circuit: ir.AnalogCircuit) -> EmulatorProgram:
self.assignments = AssignmentScan(self.assignments).emit(circuit.sequence)
self.is_hyperfine = IsHyperfineSequence().emit(circuit)
self.n_atoms = circuit.register.n_atoms
self.n_sites = circuit.register.n_sites

self.visit(circuit)
return EmulatorProgram(
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/emulate/codegen/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def visit_fields(self, fields: Fields):

def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
if (self.register, detuning_data) in self.compile_cache.operator_cache:
return self.compile_cache.operator_cache[(self.space, detuning_data)]
return self.compile_cache.operator_cache[(self.register, detuning_data)]

diagonal = np.zeros(self.space.size, dtype=np.float64)
if self.space.atom_type == TwoLevelAtomType():
Expand All @@ -110,7 +110,7 @@ def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
state = ThreeLevelAtomType.State.Hyperfine

for atom_index, value in detuning_data.target_atoms.items():
diagonal[self.space.is_state_at(atom_index, state)] += float(value)
diagonal[self.space.is_state_at(atom_index, state)] -= float(value)

self.compile_cache.operator_cache[(self.register, detuning_data)] = diagonal
return diagonal
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/emulate/ir/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class RabiOperator:
phase: Optional[Callable[[float], float]] = None

def dot(self, register: NDArray, time: float):
amplitude = self.amplitude(time)
amplitude = self.amplitude(time) / 2
if self.phase is None:
return self.op.dot(register) * amplitude

Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/task/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from bloqade.task.quera import QuEraTask
from bloqade.task.braket import BraketTask
from bloqade.task.braket_simulator import BraketEmulatorTask
from bloqade.task.bloqade import BloqadeTask

from bloqade.builder.base import Builder

Expand Down Expand Up @@ -45,7 +46,7 @@ def json(self, **options) -> str:
@Serializer.register
class LocalBatch(Serializable):
source: Optional[Builder]
tasks: OrderedDict[int, BraketEmulatorTask]
tasks: OrderedDict[int, Union[BraketEmulatorTask, BloqadeTask]]
name: Optional[str] = None

def report(self) -> Report:
Expand Down
Loading

0 comments on commit 4c72a23

Please sign in to comment.