Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing more bugs in emulator. #614

Merged
merged 18 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/bloqade/codegen/common/is_hyperfine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any
import bloqade.ir.analog_circuit as analog_circuit
import bloqade.ir.control.sequence as sequence
from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor


class IsHyperfineSequence(AnalogCircuitVisitor):
def __init__(self):
self.is_hyperfine = False

def visit_analog_circuit(self, ast: analog_circuit.AnalogCircuit) -> Any:
self.visit(ast.sequence)

def visit_append_sequence(self, ast: sequence.Append) -> Any:
list(map(self.visit, ast.sequences))
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved

def visit_slice_sequence(self, ast: sequence.Slice) -> Any:
self.visit(ast.sequence)

def visit_named_sequence(self, ast: sequence.NamedSequence) -> Any:
self.visit(ast.sequence)

def visit_sequence(self, ast: sequence.Sequence) -> Any:
self.is_hyperfine = self.is_hyperfine or sequence.hyperfine in ast.pulses

def emit(self, ast: analog_circuit.AnalogCircuit) -> bool:
self.visit(ast)
return self.is_hyperfine
43 changes: 28 additions & 15 deletions src/bloqade/codegen/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import bloqade.ir.control.sequence as sequence
import bloqade.ir.control.pulse as pulse
import bloqade.ir.control.waveform as waveform
import bloqade.ir.scalar as scalar
import bloqade.ir.control.field as field
import bloqade.ir as ir

from bloqade.codegen.common.is_hyperfine import IsHyperfineSequence
from bloqade.emulate.ir.atom_type import ThreeLevelAtom, TwoLevelAtom
from bloqade.emulate.ir.emulator import (
DetuningOperatorData,
Expand Down Expand Up @@ -59,20 +59,18 @@ def __init__(
self.original_index = []

def visit_analog_circuit(self, ast: ir.AnalogCircuit):
self.n_atoms = ast.register.n_atoms

self.visit(ast.sequence)
self.visit(ast.register)
self.visit(ast.sequence)

def visit_register(self, ast: AtomArrangement) -> Any:
positions = []
for original_index, loc_info in enumerate(ast.enumerate()):
for org_index, loc_info in enumerate(ast.enumerate()):
if loc_info.filling == SiteFilling.filled:
position = tuple([pos(**self.assignments) for pos in loc_info.position])
positions.append(position)
self.original_index.append(original_index)
self.original_index.append(org_index)

if sequence.hyperfine in self.level_couplings:
if self.is_hyperfine:
self.register = Register(
ThreeLevelAtom,
positions,
Expand All @@ -91,7 +89,6 @@ def visit_sequence(self, ast: sequence.Sequence) -> None:
sequence.rydberg: LevelCoupling.Rydberg,
}
for level_coupling, sub_pulse in ast.pulses.items():
self.level_couplings.add(level_coupling)
self.visit(sub_pulse)
self.pulses[level_coupling_mapping[level_coupling]] = Fields(
detuning=self.detuning_terms,
Expand Down Expand Up @@ -144,9 +141,18 @@ def visit_assigned_run_time_vector(

def visit_scaled_locations(self, ast: ScaledLocations) -> Dict[int, Decimal]:
target_atoms = {}

for location in ast.value.keys():
if location.value >= self.n_sites or location.value < 0:
raise ValueError(
f"Location {location.value} is out of bounds for register with "
f"{self.n_sites} sites."
)

for new_index, original_index in enumerate(self.original_index):
value = ast.value.get(original_index, scalar.Literal(0))
target_atoms[new_index] = value(**self.assignments)
value = ast.value.get(field.Location(original_index))
if value is not None:
target_atoms[new_index] = value(**self.assignments)

return target_atoms

Expand Down Expand Up @@ -174,8 +180,9 @@ def visit_detuning(self, ast: Optional[Field]):
for atom in range(self.n_atoms):
wf = sum(
(
target_atom_dict[sm].get(atom, 0.0) * wf
target_atom_dict[sm][atom] * wf
for sm, wf in ast.value.items()
if atom in target_atom_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)
Expand Down Expand Up @@ -221,8 +228,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
for atom in range(self.n_atoms):
amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm].get(atom, 0.0) * wf
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)
Expand Down Expand Up @@ -274,16 +282,18 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
for atom in range(self.n_atoms):
phase_wf = sum(
(
phase_target_atoms_dict[sm].get(atom, 0.0) * wf
phase_target_atoms_dict[sm][atom] * wf
for sm, wf in phase.value.items()
if atom in phase_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)

amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm].get(atom, 0.0) * wf
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
)
Expand All @@ -310,6 +320,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):

def emit(self, circuit: ir.AnalogCircuit) -> EmulatorProgram:
self.assignments = AssignmentScan(self.assignments).emit(circuit.sequence)
self.is_hyperfine = IsHyperfineSequence().emit(circuit)
self.n_atoms = circuit.register.n_atoms
self.n_sites = circuit.register.n_sites

self.visit(circuit)
return EmulatorProgram(
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/emulate/codegen/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def visit_fields(self, fields: Fields):

def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
if (self.register, detuning_data) in self.compile_cache.operator_cache:
return self.compile_cache.operator_cache[(self.space, detuning_data)]
return self.compile_cache.operator_cache[(self.register, detuning_data)]

diagonal = np.zeros(self.space.size, dtype=np.float64)
if self.space.atom_type == TwoLevelAtomType():
Expand All @@ -110,7 +110,7 @@ def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
state = ThreeLevelAtomType.State.Hyperfine

for atom_index, value in detuning_data.target_atoms.items():
diagonal[self.space.is_state_at(atom_index, state)] += float(value)
diagonal[self.space.is_state_at(atom_index, state)] -= float(value)

self.compile_cache.operator_cache[(self.register, detuning_data)] = diagonal
return diagonal
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/emulate/ir/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class RabiOperator:
phase: Optional[Callable[[float], float]] = None

def dot(self, register: NDArray, time: float):
amplitude = self.amplitude(time)
amplitude = self.amplitude(time) / 2
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved
if self.phase is None:
return self.op.dot(register) * amplitude

Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/task/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from bloqade.task.quera import QuEraTask
from bloqade.task.braket import BraketTask
from bloqade.task.braket_simulator import BraketEmulatorTask
from bloqade.task.bloqade import BloqadeTask

from bloqade.builder.base import Builder

Expand Down Expand Up @@ -45,7 +46,7 @@ def json(self, **options) -> str:
@Serializer.register
class LocalBatch(Serializable):
source: Optional[Builder]
tasks: OrderedDict[int, BraketEmulatorTask]
tasks: OrderedDict[int, Union[BraketEmulatorTask, BloqadeTask]]
name: Optional[str] = None

def report(self) -> Report:
Expand Down
Loading