Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing more bugs in emulator. #614

Merged
merged 18 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/bloqade/codegen/common/is_hyperfine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any
import bloqade.ir.analog_circuit as analog_circuit
import bloqade.ir.control.sequence as sequence
from bloqade.ir.visitor.analog_circuit import AnalogCircuitVisitor


class IsHyperfineSequence(AnalogCircuitVisitor):
def __init__(self):
self.is_hyperfine = False

def visit_analog_circuit(self, ast: analog_circuit.AnalogCircuit) -> Any:
self.visit(ast.sequence)

def visit_append_sequence(self, ast: sequence.Append) -> Any:
list(map(self.visit, ast.sequences))
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved

def visit_slice_sequence(self, ast: sequence.Slice) -> Any:
self.visit(ast.sequence)

def visit_named_sequence(self, ast: sequence.NamedSequence) -> Any:
self.visit(ast.sequence)

def visit_sequence(self, ast: sequence.Sequence) -> Any:
self.is_hyperfine = self.is_hyperfine or sequence.hyperfine in ast.pulses

def emit(self, ast: analog_circuit.AnalogCircuit) -> bool:
self.visit(ast)
return self.is_hyperfine
31 changes: 20 additions & 11 deletions src/bloqade/emulate/codegen/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import bloqade.ir.control.sequence as sequence
import bloqade.ir.control.pulse as pulse
import bloqade.ir.control.waveform as waveform
import bloqade.ir.scalar as scalar
import bloqade.ir.control.field as field
import bloqade.ir as ir

from bloqade.codegen.common.is_hyperfine import IsHyperfineSequence
from bloqade.emulate.ir.atom_type import ThreeLevelAtom, TwoLevelAtom
from bloqade.emulate.ir.emulator import (
DetuningOperatorData,
Expand Down Expand Up @@ -59,20 +59,18 @@ def __init__(
self.original_index = []

def visit_analog_circuit(self, ast: ir.AnalogCircuit):
self.n_atoms = ast.register.n_atoms

self.visit(ast.sequence)
self.visit(ast.register)
self.visit(ast.sequence)

def visit_register(self, ast: AtomArrangement) -> Any:
positions = []
for original_index, loc_info in enumerate(ast.enumerate()):
for org_index, loc_info in enumerate(ast.enumerate()):
if loc_info.filling == SiteFilling.filled:
position = tuple([pos(**self.assignments) for pos in loc_info.position])
positions.append(position)
self.original_index.append(original_index)
self.original_index.append(org_index)

if sequence.hyperfine in self.level_couplings:
if self.is_hyperfine:
self.register = Register(
ThreeLevelAtom,
positions,
Expand All @@ -91,7 +89,6 @@ def visit_sequence(self, ast: sequence.Sequence) -> None:
sequence.rydberg: LevelCoupling.Rydberg,
}
for level_coupling, sub_pulse in ast.pulses.items():
self.level_couplings.add(level_coupling)
self.visit(sub_pulse)
self.pulses[level_coupling_mapping[level_coupling]] = Fields(
detuning=self.detuning_terms,
Expand Down Expand Up @@ -144,9 +141,18 @@ def visit_assigned_run_time_vector(

def visit_scaled_locations(self, ast: ScaledLocations) -> Dict[int, Decimal]:
target_atoms = {}

for location in ast.value.keys():
if location.value >= self.n_sites or location.value < 0:
raise ValueError(
f"Location {location.value} is out of bounds for register with "
f"{self.n_sites} sites."
)

for new_index, original_index in enumerate(self.original_index):
value = ast.value.get(original_index, scalar.Literal(0))
target_atoms[new_index] = value(**self.assignments)
value = ast.value.get(field.Location(original_index))
if value is not None:
target_atoms[new_index] = value(**self.assignments)

return target_atoms

Expand Down Expand Up @@ -310,6 +316,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):

def emit(self, circuit: ir.AnalogCircuit) -> EmulatorProgram:
self.assignments = AssignmentScan(self.assignments).emit(circuit.sequence)
self.is_hyperfine = IsHyperfineSequence().emit(circuit)
self.n_atoms = circuit.register.n_atoms
self.n_sites = circuit.register.n_sites

self.visit(circuit)
return EmulatorProgram(
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/emulate/codegen/rydberg_hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def visit_fields(self, fields: Fields):

def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
if (self.register, detuning_data) in self.compile_cache.operator_cache:
return self.compile_cache.operator_cache[(self.space, detuning_data)]
return self.compile_cache.operator_cache[(self.register, detuning_data)]

diagonal = np.zeros(self.space.size, dtype=np.float64)
if self.space.atom_type == TwoLevelAtomType():
Expand All @@ -110,7 +110,7 @@ def visit_detuning_operator_data(self, detuning_data: DetuningOperatorData):
state = ThreeLevelAtomType.State.Hyperfine

for atom_index, value in detuning_data.target_atoms.items():
diagonal[self.space.is_state_at(atom_index, state)] += float(value)
diagonal[self.space.is_state_at(atom_index, state)] -= float(value)

self.compile_cache.operator_cache[(self.register, detuning_data)] = diagonal
return diagonal
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/emulate/ir/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class RabiOperator:
phase: Optional[Callable[[float], float]] = None

def dot(self, register: NDArray, time: float):
amplitude = self.amplitude(time)
amplitude = self.amplitude(time) / 2
Roger-luo marked this conversation as resolved.
Show resolved Hide resolved
if self.phase is None:
return self.op.dot(register) * amplitude

Expand Down
93 changes: 85 additions & 8 deletions tests/test_python_emulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from bloqade import start, var
from bloqade import start, var, cast
from bloqade.atom_arrangement import Chain
from bloqade.serialize import dumps, loads
import numpy as np
from beartype.typing import Dict
from scipy.stats import ks_2samp


def test_integration_1():
Expand All @@ -13,7 +16,7 @@ def test_integration_1():
)
.amplitude.uniform.piecewise_linear([0.1, "ramp_time", 0.1], [0, 10, 10, 0])
.assign(ramp_time=3.0)
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 5).tolist())
.bloqade.python()
.run(10000, cache_matrices=True, blockade_radius=6.0, interaction_picture=True)
.report()
Expand All @@ -35,7 +38,7 @@ def test_integration_2():
[0.1, ramp_time / 2, ramp_time / 2, 0.1], [0, 0, np.pi, np.pi]
)
.assign(ramp_time=3.0)
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 5).tolist())
.bloqade.python()
.run(10000, cache_matrices=True, blockade_radius=6.0)
.report()
Expand All @@ -59,7 +62,7 @@ def test_integration_3():
.amplitude.var("rabi_mask")
.fn(lambda t: 4 * np.sin(3 * t), ramp_time + 0.2)
.assign(ramp_time=3.0, rabi_mask=[10.0, 0.1])
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 4).tolist())
.bloqade.python()
.run(10000, cache_matrices=True, blockade_radius=6.0)
.report()
Expand All @@ -82,7 +85,7 @@ def test_integration_4():
.amplitude.location(1)
.linear(0.0, 1.0, ramp_time + 0.2)
.assign(ramp_time=3.0, rabi_mask=[10.0, 0.1])
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 5).tolist())
.bloqade.python()
.run(10000, cache_matrices=True, blockade_radius=6.0)
.report()
Expand All @@ -103,7 +106,7 @@ def test_integration_5():
.phase.location(1)
.linear(0.0, 1.0, ramp_time + 0.2)
.assign(ramp_time=3.0)
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 5).tolist())
.bloqade.python()
.run(10000, cache_matrices=True, blockade_radius=6.0)
.report()
Expand All @@ -128,7 +131,7 @@ def test_integration_6():
.phase.location(1)
.linear(0.0, 1.0, ramp_time + 0.2)
.assign(ramp_time=3.0)
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 5).tolist())
.bloqade.python()
.run(10000, cache_matrices=True, blockade_radius=6.0)
.report()
Expand All @@ -151,11 +154,85 @@ def test_serialization():
.amplitude.location(1)
.linear(0.0, 1.0, ramp_time + 0.2)
.assign(ramp_time=3.0, rabi_mask=[10.0, 0.1])
.batch_assign(r=np.linspace(4, 10, 11).tolist())
.batch_assign(r=np.linspace(0.1, 4, 5).tolist())
.bloqade.python()
._compile(100)
)

obj_str = dumps(batch)
batch2 = loads(obj_str)
assert isinstance(batch2, type(batch))


def KS_test(
weinbe58 marked this conversation as resolved.
Show resolved Hide resolved
lhs_counts: Dict[str, int], rhs_counts: Dict[str, int], alpha: float = 0.05
) -> None:
lhs_samples = []
rhs_samples = []

for bitstring, count in lhs_counts.items():
lhs_samples += [int(bitstring, 2)] * count

for bitstring, count in rhs_counts.items():
rhs_samples += [int(bitstring, 2)] * count

result = ks_2samp(lhs_samples, rhs_samples, method="exact")

assert result.pvalue > alpha


def test_bloqade_against_braket():
weinbe58 marked this conversation as resolved.
Show resolved Hide resolved
np.random.seed(9123892)
durations = cast([0.1, 1.0, 0.1])

prog = (
Chain(5, lattice_spacing=6.1)
.rydberg.detuning.uniform.piecewise_linear(durations, [-20, -20, "d", "d"])
.amplitude.uniform.piecewise_linear(durations, [0, 15, 15, 0])
.phase.uniform.constant(0.3, sum(durations))
.batch_assign(d=[0, 10, 20, 30, 40])
)

nshots = 1000
a = prog.bloqade.python().run(nshots, cache_matrices=True).report().counts
b = prog.braket.local_emulator().run(nshots).report().counts

for lhs, rhs in zip(a, b):
KS_test(lhs, rhs)


def test_bloqade_against_braket_2():
np.random.seed(192839812)
durations = cast([0.1, 1.0, 0.1])
values = [0, 15, 15, 0]

prog_1 = (
Chain(5, lattice_spacing=6.1)
.rydberg.detuning.uniform.piecewise_linear(durations, [-20, -20, "d", "d"])
.amplitude.uniform.piecewise_linear(durations, values)
.batch_assign(d=[0, 10, 20, 30, 40])
)
prog_2 = (
Chain(5, lattice_spacing=6.1)
.rydberg.detuning.uniform.piecewise_linear(durations, [-20, -20, "d", "d"])
.amplitude.location(0)
.piecewise_linear(durations, values)
.amplitude.location(1)
.piecewise_linear(durations, values)
.amplitude.location(2)
.piecewise_linear(durations, values)
.amplitude.location(3)
.piecewise_linear(durations, values)
.amplitude.location(4)
.piecewise_linear(durations, values)
.phase.location(0)
.constant(0.0, sum(durations))
.batch_assign(d=[0, 10, 20, 30, 40])
)

nshots = 1000
a = prog_2.bloqade.python().run(nshots, cache_matrices=True).report().counts
b = prog_1.braket.local_emulator().run(nshots).report().counts

for lhs, rhs in zip(a, b):
KS_test(lhs, rhs)