Skip to content

Commit

Permalink
Merge branch 'main' into fixing-IR-printing
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 authored Sep 28, 2023
2 parents 2b74cd2 + 20a9ef9 commit dc4727a
Show file tree
Hide file tree
Showing 6 changed files with 1,824 additions and 19 deletions.
50 changes: 45 additions & 5 deletions src/bloqade/codegen/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def emit(self, ast: waveform.Waveform) -> CompiledWaveform:

class EmulatorProgramCodeGen(AnalogCircuitVisitor):
def __init__(
self, assignments: Dict[str, Number] = {}, blockade_radius: Real = 0.0
self,
assignments: Dict[str, Number] = {},
blockade_radius: Real = 0.0,
use_hyperfine: bool = False,
):
self.blockade_radius = Decimal(str(blockade_radius))
self.assignments = assignments
Expand All @@ -57,6 +60,7 @@ def __init__(
self.pulses = {}
self.level_couplings = set()
self.original_index = []
self.is_hyperfine = use_hyperfine

def visit_analog_circuit(self, ast: ir.AnalogCircuit):
self.visit(ast.register)
Expand Down Expand Up @@ -129,6 +133,7 @@ def visit_run_time_vector(self, ast: RunTimeVector) -> Dict[int, Decimal]:
return {
new_index: Decimal(str(value[original_index]))
for new_index, original_index in enumerate(self.original_index)
if value[original_index] != 0
}

def visit_assigned_run_time_vector(
Expand All @@ -137,6 +142,7 @@ def visit_assigned_run_time_vector(
return {
new_index: Decimal(str(ast.value[original_index]))
for new_index, original_index in enumerate(self.original_index)
if ast.value[original_index] != 0
}

def visit_scaled_locations(self, ast: ScaledLocations) -> Dict[int, Decimal]:
Expand All @@ -151,7 +157,7 @@ def visit_scaled_locations(self, ast: ScaledLocations) -> Dict[int, Decimal]:

for new_index, original_index in enumerate(self.original_index):
value = ast.value.get(field.Location(original_index))
if value is not None:
if value is not None and value != 0:
target_atoms[new_index] = value(**self.assignments)

return target_atoms
Expand All @@ -178,6 +184,9 @@ def visit_detuning(self, ast: Optional[Field]):
target_atom_dict = {sm: self.visit(sm) for sm in ast.drives.keys()}

for atom in range(self.n_atoms):
if not any(atom in value for value in target_atom_dict.values()):
continue

wf = sum(
(
target_atom_dict[sm][atom] * wf
Expand All @@ -186,6 +195,7 @@ def visit_detuning(self, ast: Optional[Field]):
),
start=waveform.Constant(0.0, 0.0),
)

self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -212,10 +222,20 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)

target_atoms = self.visit(sm)

if len(target_atoms) == 0:
continue
elif len(target_atoms) == 1:
(scale,) = target_atoms.values()
if scale != 1:
wf = scale * wf

terms.append(
RabiTerm(
operator_data=RabiOperatorData(
target_atoms=self.visit(sm),
target_atoms=target_atoms,
operator_type=RabiOperatorType.RabiSymmetric,
),
amplitude=self.waveform_compiler.emit(wf),
Expand All @@ -226,6 +246,11 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
sm: self.visit(sm) for sm in amplitude.drives.keys()
}
for atom in range(self.n_atoms):
if not any(
atom in value for value in amplitude_target_atoms_dict.values()
):
continue

amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm][atom] * wf
Expand Down Expand Up @@ -262,10 +287,20 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)

target_atoms = self.visit(sm)

if len(target_atoms) == 0:
continue
elif len(target_atoms) == 1:
(scale,) = target_atoms.values()
if scale != 1:
wf = scale * wf

terms.append(
RabiTerm(
operator_data=RabiOperatorData(
target_atoms=self.visit(sm),
target_atoms=target_atoms,
operator_type=RabiOperatorType.RabiAsymmetric,
),
amplitude=self.waveform_compiler.emit(wf),
Expand All @@ -280,6 +315,11 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):

terms = []
for atom in range(self.n_atoms):
if not any(
atom in value for value in amplitude_target_atoms_dict.values()
):
continue

phase_wf = sum(
(
phase_target_atoms_dict[sm][atom] * wf
Expand Down Expand Up @@ -320,7 +360,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):

def emit(self, circuit: ir.AnalogCircuit) -> EmulatorProgram:
self.assignments = AssignmentScan(self.assignments).emit(circuit.sequence)
self.is_hyperfine = IsHyperfineSequence().emit(circuit)
self.is_hyperfine = IsHyperfineSequence().emit(circuit) or self.is_hyperfine
self.n_atoms = circuit.register.n_atoms
self.n_sites = circuit.register.n_sites

Expand Down
19 changes: 16 additions & 3 deletions src/bloqade/emulate/ir/atom_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,14 @@ def swap_state_at(
mask_2 = self.is_state_at(configurations, index, state_2)
delta = state_2.value - state_1.value

output[mask_1] += delta * 3**index
output[mask_2] -= delta * 3**index
if delta < 0:
mask_1, mask_2 = mask_2, mask_1
delta = -delta

shift_value = np.array(delta * 3**index, dtype=configurations.dtype)

output[mask_1] += shift_value
output[mask_2] -= shift_value

np.logical_or(mask_1, mask_2, out=mask_1)

Expand All @@ -105,7 +111,14 @@ def transition_state_at(
output_configs = configurations[input_configs]

delta = to.value - fro.value
return (input_configs, output_configs + (delta * 3**index))

if delta < 0:
delta = -delta
output_configs -= np.array(delta * 3**index, dtype=output_configs.dtype)
else:
output_configs += np.array(delta * 3**index, dtype=output_configs.dtype)

return (input_configs, output_configs)


@dataclass(frozen=True)
Expand Down
30 changes: 21 additions & 9 deletions src/bloqade/emulate/ir/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .emulator import Register
from .atom_type import AtomType

MAX_PRINT_SIZE = 30


class SpaceType(str, Enum):
FullSpace = "full_space"
Expand Down Expand Up @@ -44,6 +46,11 @@ def create(cls, register: "Register"):
configurations = np.arange(Ns, dtype=np.min_scalar_type(Ns - 1))

if all(len(sub_list) == 0 for sub_list in check_atoms):
min_int_type = np.min_scalar_type(configurations[-1])
# defauly to 32 bit if smaller than 32 bit
config_type = np.result_type(min_int_type, np.uint32)
configurations = configurations.astype(config_type)

return Space(SpaceType.FullSpace, atom_type, sites, configurations)

for index_1, indices in enumerate(check_atoms):
Expand Down Expand Up @@ -96,18 +103,18 @@ def swap_state_at(self, index: int, state_1: int, state_2: int) -> NDArray:
row_indices, col_config = self.atom_type.swap_state_at(
self.configurations, index, state_1, state_2
)
if self.space_type == SpaceType.FullSpace:
if self.space_type is SpaceType.FullSpace:
return (row_indices, col_config)
else:
col_indices = np.searchsorted(self.configurations, col_config)

mask = col_indices < self.size
mask[mask] = col_config[mask] == self.configurations[col_indices[mask]]

if not np.all(mask):
if isinstance(row_indices, slice):
return mask, col_indices[mask]
else:
row_indices = np.argwhere(row_indices).ravel()
return row_indices[mask], col_indices[mask]
else:
return row_indices, col_indices
Expand All @@ -116,7 +123,7 @@ def transition_state_at(self, index: int, fro: int, to: int) -> NDArray:
row_indices, col_config = self.atom_type.transition_state_at(
self.configurations, index, fro, to
)
if self.space_type == SpaceType.FullSpace:
if self.space_type is SpaceType.FullSpace:
return (row_indices, col_config)
else:
col_indices = np.searchsorted(self.configurations, col_config)
Expand All @@ -126,7 +133,7 @@ def transition_state_at(self, index: int, fro: int, to: int) -> NDArray:
col_indices = col_indices[mask]
row_indices = row_indices[mask]

mask = col_config == self.configurations[col_indices]
mask = col_config[mask] == self.configurations[col_indices]
col_indices = col_indices[mask]
row_indices = row_indices[mask]

Expand All @@ -138,15 +145,17 @@ def fock_state_to_index(self, fock_state: str) -> int:
return state_int
else:
index = np.searchsorted(self.configurations, state_int)

if state_int != self.configurations[index]:
if index >= self.size or state_int != self.configurations[index]:
raise ValueError(
"state: {fock_state} not in rydberg blockade subspace."
)

return index

def index_to_fock_state(self, index: int) -> str:
if index < 0 or index >= self.size:
raise ValueError(f"index: {index} out of bounds.")

if self.space_type is SpaceType.FullSpace:
return self.atom_type.integer_to_string(index, self.n_atoms)
else:
Expand Down Expand Up @@ -185,20 +194,23 @@ def __str__(self):

n_digits = len(str(self.size - 1))
fmt = "{{index: >{}d}}. {{fock_state:s}}\n".format(n_digits)
if self.size < 50:
if self.size < MAX_PRINT_SIZE:
for index, state_int in enumerate(self.configurations):
fock_state = self.atom_type.integer_to_string(state_int, self.n_atoms)
output = output + fmt.format(index=index, fock_state=fock_state)

else:
for index, state_int in enumerate(self.configurations[:25]):
lower_index = MAX_PRINT_SIZE // 2 + (MAX_PRINT_SIZE % 2)
upper_index = self.size - MAX_PRINT_SIZE // 2

for index, state_int in enumerate(self.configurations[:lower_index]):
fock_state = self.atom_type.integer_to_string(state_int, self.n_atoms)
output = output + fmt.format(index=index, fock_state=fock_state)

output += (n_digits * " ") + "...\n"

for index, state_int in enumerate(
self.configurations[-25:], start=self.size - 25
self.configurations[upper_index:], start=self.size - MAX_PRINT_SIZE // 2
):
fock_state = self.atom_type.integer_to_string(state_int, self.n_atoms)
output = output + fmt.format(index=index, fock_state=fock_state)
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/emulate/sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def dot(self, other):

return result

def to_csr(self):
def tocsr(self):
indptr = np.zeros(self.n_row + 1, dtype=np.int64)
indptr[1:][self.row_indices] = 1
np.cumsum(indptr, out=indptr)
Expand Down
Loading

0 comments on commit dc4727a

Please sign in to comment.