Skip to content

Commit

Permalink
Fixing ir printing (#629)
Browse files Browse the repository at this point in the history
* making all IR pydantic. dataclasses

* removing some reprs

* removing old repr for waveforms.

* add drop-in node for complicated dict in IR.

* fixing printers for field.

* removing all __str__.

* renaming field.

* renaming field.

* renaming field.

* fixing pretty-printer for locations.

* fixing issues with renaming field attr.

* adding tests for lattice pprint.

* adding equality for `KeyValuePair` for testing

* fixing tests for fields printer

* fixing more tests.

* fixing tests

* fixing printing for AnalogCircuit.

* removing commented code.

* removing commented code.

* modify traits for builder.

* removing `KeyValuePair`

* fixing field test.

* fixing tests for new printing

* fixing emulator ir tests.

* moving pprint to `Routine`.

* cleaning up if statement

* removing `repr=False`

* adding dumb inline printing for scalars.

* adding unit tests for printer.
  • Loading branch information
weinbe58 authored Sep 28, 2023
1 parent 62a179e commit a60fb8e
Show file tree
Hide file tree
Showing 31 changed files with 590 additions and 645 deletions.
7 changes: 5 additions & 2 deletions src/bloqade/builder/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from typing import Optional, Union, List
from numbers import Real

from bloqade.builder.parse.trait import CompileJSON, Parse
from bloqade.builder.parse.trait import Parse, Show

ParamType = Union[Real, List[Real]]


class Builder(CompileJSON, Parse):
class Builder(Parse, Show):
__match_args__ = ("__parent__",)

def __init__(
self,
parent: Optional["Builder"] = None,
) -> None:
self.__parent__ = parent

def __str__(self):
return str(self.parse())
7 changes: 1 addition & 6 deletions src/bloqade/builder/parse/trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ def n_atoms(self: "Builder"):
else:
return register.n_atoms

def __repr__(self: "Builder"):
from .builder import Parser

analog_circ, metas = Parser().parse(self)

return repr(analog_circ) + "\n" + repr(metas)

class Show:
def show(self, batch_id: int = 0):
display_builder(self, batch_id)
4 changes: 2 additions & 2 deletions src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def visit_pulse(self, ast: pulse.Pulse) -> pulse.Pulse:
)

def visit_append_pulse(self, ast: pulse.Append) -> pulse.Append:
return pulse.Append(list(map(self.visit, ast.value)))
return pulse.Append(list(map(self.visit, ast.pulses)))

def visit_slice_pulse(self, ast: pulse.Slice) -> pulse.Slice:
return pulse.Slice(self.visit(ast.pulse), self.scalar_visitor(ast.interval))
Expand All @@ -268,7 +268,7 @@ def visit_named_pulse(self, ast: pulse.NamedPulse) -> Any:

def visit_field(self, ast: field.Field) -> field.Field:
return field.Field(
{self.visit(sm): self.visit(wf) for sm, wf in ast.value.items()}
{self.visit(sm): self.visit(wf) for sm, wf in ast.drives.items()}
)

def visit_uniform_modulation(self, ast: field.UniformModulation) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/codegen/common/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def visit_assigned_run_time_vector(
def visit_field(self, ast: field.Field) -> Any:
return {
"field": {
"value": [(self.visit(k), self.visit(v)) for k, v in ast.value.items()]
"value": [(self.visit(k), self.visit(v)) for k, v in ast.drives.items()]
}
}

Expand Down
34 changes: 17 additions & 17 deletions src/bloqade/codegen/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def visit_detuning(self, ast: Optional[Field]):

terms = []

if len(ast.value) <= self.n_atoms:
for sm, wf in ast.value.items():
if len(ast.drives) <= self.n_atoms:
for sm, wf in ast.drives.items():
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -181,7 +181,7 @@ def visit_detuning(self, ast: Optional[Field]):
)
)
else:
target_atom_dict = {sm: self.visit(sm) for sm in ast.value.keys()}
target_atom_dict = {sm: self.visit(sm) for sm in ast.drives.keys()}

for atom in range(self.n_atoms):
if not any(atom in value for value in target_atom_dict.values()):
Expand All @@ -190,7 +190,7 @@ def visit_detuning(self, ast: Optional[Field]):
wf = sum(
(
target_atom_dict[sm][atom] * wf
for sm, wf in ast.value.items()
for sm, wf in ast.drives.items()
if atom in target_atom_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand All @@ -217,8 +217,8 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
if amplitude is None:
return terms

if phase is None and len(amplitude.value) <= self.n_atoms:
for sm, wf in amplitude.value.items():
if phase is None and len(amplitude.drives) <= self.n_atoms:
for sm, wf in amplitude.drives.items():
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -243,7 +243,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
)
elif phase is None: # fully local real rabi fields
amplitude_target_atoms_dict = {
sm: self.visit(sm) for sm in amplitude.value.keys()
sm: self.visit(sm) for sm in amplitude.drives.keys()
}
for atom in range(self.n_atoms):
if not any(
Expand All @@ -254,7 +254,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
for sm, wf in amplitude.drives.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand All @@ -274,16 +274,16 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
)
)
elif (
len(phase.value) == 1
and UniformModulation() in phase.value
and len(amplitude.value) <= self.n_atoms
len(phase.drives) == 1
and UniformModulation() in phase.drives
and len(amplitude.drives) <= self.n_atoms
):
(phase_waveform,) = phase.value.values()
(phase_waveform,) = phase.drives.values()
rabi_phase = self.waveform_compiler.emit(phase_waveform)
self.duration = max(
float(phase_waveform.duration(**self.assignments)), self.duration
)
for sm, wf in amplitude.value.items():
for sm, wf in amplitude.drives.items():
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -308,9 +308,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
)
)
else:
phase_target_atoms_dict = {sm: self.visit(sm) for sm in phase.value.keys()}
phase_target_atoms_dict = {sm: self.visit(sm) for sm in phase.drives.keys()}
amplitude_target_atoms_dict = {
sm: self.visit(sm) for sm in amplitude.value.keys()
sm: self.visit(sm) for sm in amplitude.drives.keys()
}

terms = []
Expand All @@ -323,7 +323,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
phase_wf = sum(
(
phase_target_atoms_dict[sm][atom] * wf
for sm, wf in phase.value.items()
for sm, wf in phase.drives.items()
if atom in phase_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand All @@ -332,7 +332,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
for sm, wf in amplitude.drives.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand Down
24 changes: 12 additions & 12 deletions src/bloqade/codegen/hardware/quera.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ def visit_assigned_run_time_vector(self, ast: field.AssignedRunTimeVector) -> An
self.post_visit_spatial_modulation(lattice_site_coefficients)

def visit_detuning(self, ast: field.Field) -> Any:
if len(ast.value) == 1 and field.Uniform in ast.value:
if len(ast.drives) == 1 and field.Uniform in ast.drives:
times, values = PiecewiseLinearCodeGen(self.assignments).visit(
ast.value[field.Uniform]
ast.drives[field.Uniform]
)

times = QuEraCodeGen.convert_time_to_SI_units(times)
Expand All @@ -429,8 +429,8 @@ def visit_detuning(self, ast: field.Field) -> Any:
self.detuning = task_spec.Detuning(
global_=task_spec.GlobalField(times=times, values=values)
)
elif len(ast.value) == 1:
((spatial_modulation, waveform),) = ast.value.items()
elif len(ast.drives) == 1:
((spatial_modulation, waveform),) = ast.drives.items()

times, values = PiecewiseLinearCodeGen(self.assignments).visit(waveform)

Expand All @@ -446,18 +446,18 @@ def visit_detuning(self, ast: field.Field) -> Any:
lattice_site_coefficients=self.lattice_site_coefficients,
),
)
elif len(ast.value) == 2 and field.Uniform in ast.value:
elif len(ast.drives) == 2 and field.Uniform in ast.drives:
# will only be two keys
for k in ast.value.keys():
for k in ast.drives.keys():
if k == field.Uniform:
global_times, global_values = PiecewiseLinearCodeGen(
self.assignments
).visit(ast.value[field.Uniform])
).visit(ast.drives[field.Uniform])
else: # can be field.RunTimeVector or field.ScaledLocations
spatial_modulation = k
local_times, local_values = PiecewiseLinearCodeGen(
self.assignments
).visit(ast.value[k])
).visit(ast.drives[k])

self.visit(spatial_modulation) # just visit the non-uniform locations

Expand All @@ -483,9 +483,9 @@ def visit_detuning(self, ast: field.Field) -> Any:
)

def visit_rabi_amplitude(self, ast: field.Field) -> Any:
if len(ast.value) == 1 and field.Uniform in ast.value:
if len(ast.drives) == 1 and field.Uniform in ast.drives:
times, values = PiecewiseLinearCodeGen(self.assignments).visit(
ast.value[field.Uniform]
ast.drives[field.Uniform]
)

times = QuEraCodeGen.convert_time_to_SI_units(times)
Expand All @@ -502,9 +502,9 @@ def visit_rabi_amplitude(self, ast: field.Field) -> Any:
)

def visit_rabi_phase(self, ast: field.Field) -> Any:
if len(ast.value) == 1 and field.Uniform in ast.value: # has to be global
if len(ast.drives) == 1 and field.Uniform in ast.drives: # has to be global
times, values = PiecewiseConstantCodeGen(self.assignments).visit(
ast.value[field.Uniform]
ast.drives[field.Uniform]
)

times = QuEraCodeGen.convert_time_to_SI_units(times)
Expand Down
58 changes: 25 additions & 33 deletions src/bloqade/ir/analog_circuit.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
# from numbers import Real
from typing import TYPE_CHECKING, Union
from bloqade.visualization import display_ir

if TYPE_CHECKING:
from bloqade.ir.location.base import AtomArrangement, ParallelRegister
from bloqade.ir import Sequence
from bloqade.ir.control.sequence import SequenceExpr
from bloqade.ir.location.base import AtomArrangement, ParallelRegister
from bloqade.ir.tree_print import Printer
from beartype.typing import Union
from pydantic.dataclasses import dataclass


# NOTE: this is just a dummy type bundle geometry and sequence
# information together and forward them to backends.
@dataclass(frozen=True)
class AnalogCircuit:
"""AnalogCircuit is a dummy type that bundle register and sequence together."""

def __init__(
self,
register: Union["AtomArrangement", "ParallelRegister"],
sequence: "Sequence",
):
self._sequence = sequence
self._register = register
atom_arrangement: Union[ParallelRegister, AtomArrangement]
sequence: SequenceExpr

@property
def register(self):
Expand All @@ -35,19 +31,7 @@ def register(self):
Otherwise it will be a
[`AtomArrangement`][bloqade.ir.location.base.AtomArrangement].
"""
return self._register

@property
def sequence(self):
"""Get the sequence of the program.
Returns:
Sequence: the sequence of the program.
See also [`Sequence`][bloqade.ir.control.sequence.Sequence].
"""
return self._sequence
return self.atom_arrangement

def __eq__(self, other):
if isinstance(other, AnalogCircuit):
Expand All @@ -57,22 +41,30 @@ def __eq__(self, other):

return False

def __repr__(self):
# TODO: add repr for static_params, batch_params and order
def __str__(self):
out = ""
if self._register is not None:
out += self._register.__repr__()
if self.register is not None:
out += self.register.__str__()

out += "\n"

if self._sequence is not None:
out += self._sequence.__repr__()
if self.sequence is not None:
out += self.sequence.__str__()

return out

def print_node(self):
return "AnalogCircuit"

def children(self):
return {"register": self.atom_arrangement, "sequence": self.sequence}

def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def figure(self, **assignments):
fig_reg = self._register.figure(**assignments)
fig_seq = self._sequence.figure(**assignments)
fig_reg = self.register.figure(**assignments)
fig_seq = self.sequence.figure(**assignments)
return fig_seq, fig_reg

def show(self, **assignments):
Expand Down
6 changes: 3 additions & 3 deletions src/bloqade/ir/analysis/assignment_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def visit_named_sequence(self, ast: sequence.NamedSequence):
self.visit(ast.sequence)

def visit_append_sequence(self, ast: sequence.Append):
list(map(self.visit, ast.value))
list(map(self.visit, ast.sequences))

def visit_slice_sequence(self, ast: sequence.Slice):
self.visit(ast.sequence)
Expand All @@ -99,13 +99,13 @@ def visit_named_pulse(self, ast: pulse.NamedPulse) -> Any:
self.visit(ast.pulse)

def visit_append_pulse(self, ast: pulse.Append) -> Any:
list(map(self.visit, ast.value))
list(map(self.visit, ast.pulses))

def visit_slice_pulse(self, ast: pulse.Slice) -> Any:
self.visit(ast.pulse)

def visit_field(self, ast: field.Field):
list(map(self.visit, ast.value.values()))
list(map(self.visit, ast.drives.values()))

def visit_waveform(self, ast: waveform.Waveform):
self.assignments.update(self.waveform_visitor.emit(ast))
Expand Down
8 changes: 4 additions & 4 deletions src/bloqade/ir/analysis/scan_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ def visit_assigned_run_time_vector(self, ast: field.AssignedRunTimeVector) -> An
self.vector_vars.add(ast.name)

def visit_field(self, ast: field.Field) -> Any:
list(map(self.visit, ast.value.keys()))
for value in ast.value.values():
list(map(self.visit, ast.drives.keys()))
for value in ast.drives.values():
self.visit(value)

def visit_pulse(self, ast: pulse.Pulse) -> Any:
list(map(self.visit, ast.fields.values()))

def visit_append_pulse(self, ast: pulse.Append) -> Any:
list(map(self.visit, ast.value))
list(map(self.visit, ast.pulses))

def visit_slice_pulse(self, ast: pulse.Slice) -> Any:
self.scalar_vars = self.scalar_vars.union(
Expand All @@ -223,7 +223,7 @@ def visit_sequence(self, ast: sequence.Sequence) -> Any:
list(map(self.visit, ast.pulses.values()))

def visit_append_sequence(self, ast: sequence.Append) -> Any:
list(map(self.visit, ast.value))
list(map(self.visit, ast.sequences))

def visit_slice_sequence(self, ast: sequence.Slice) -> Any:
self.scalar_vars = self.scalar_vars.union(
Expand Down
Loading

0 comments on commit a60fb8e

Please sign in to comment.