Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing ir printing #629

Merged
merged 39 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
2875840
making all IR pydantic. dataclasses
weinbe58 Sep 24, 2023
f6b467b
removing some reprs
weinbe58 Sep 24, 2023
bdfe479
removing old repr for waveforms.
weinbe58 Sep 24, 2023
834a765
add drop-in node for complicated dict in IR.
weinbe58 Sep 24, 2023
4577705
fixing printers for field.
weinbe58 Sep 24, 2023
257c75c
removing all __str__.
weinbe58 Sep 24, 2023
bf2a729
renaming field.
weinbe58 Sep 24, 2023
93cdd51
renaming field.
weinbe58 Sep 24, 2023
8abc9a2
renaming field.
weinbe58 Sep 24, 2023
dde5589
fixing pretty-printer for locations.
weinbe58 Sep 24, 2023
a3d338c
fixing issues with renaming field attr.
weinbe58 Sep 24, 2023
f3011b3
adding tests for lattice pprint.
weinbe58 Sep 24, 2023
567a2d1
adding equality for `KeyValuePair` for testing
weinbe58 Sep 24, 2023
19abefc
fixing tests for fields printer
weinbe58 Sep 24, 2023
a52db2f
fixing more tests.
weinbe58 Sep 24, 2023
6d7f0eb
fixing tests
weinbe58 Sep 25, 2023
cb2f47d
fixing printing for AnalogCircuit.
weinbe58 Sep 25, 2023
e0f6704
removing commented code.
weinbe58 Sep 25, 2023
315f2ae
removing commented code.
weinbe58 Sep 25, 2023
b4b357a
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 25, 2023
8167cd2
modify traits for builder.
weinbe58 Sep 25, 2023
5644b4b
removing `KeyValuePair`
weinbe58 Sep 25, 2023
9b0ac3b
fixing field test.
weinbe58 Sep 25, 2023
cfa7198
fixing tests for new printing
weinbe58 Sep 25, 2023
f9f96b8
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 25, 2023
e1bcf35
fixing emulator ir tests.
weinbe58 Sep 25, 2023
386318e
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 25, 2023
938e46f
moving pprint to `Routine`.
weinbe58 Sep 25, 2023
c4f5b6b
cleaning up if statement
weinbe58 Sep 26, 2023
94473a7
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 26, 2023
9ccf042
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 26, 2023
0650787
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 27, 2023
e6d035c
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 28, 2023
74df197
Merge branch 'fixing-IR-printing' of https://github.com/QuEraComputin…
weinbe58 Sep 28, 2023
2b74cd2
removing `repr=False`
weinbe58 Sep 28, 2023
dc4727a
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 28, 2023
f9da5dc
adding dumb inline printing for scalars.
weinbe58 Sep 28, 2023
e82003f
adding unit tests for printer.
weinbe58 Sep 28, 2023
f3e0fd8
Merge branch 'main' into fixing-IR-printing
weinbe58 Sep 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/bloqade/builder/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from typing import Optional, Union, List
from numbers import Real

from bloqade.builder.parse.trait import CompileJSON, Parse
from bloqade.builder.parse.trait import Parse, Show

ParamType = Union[Real, List[Real]]


class Builder(CompileJSON, Parse):
class Builder(Parse, Show):
__match_args__ = ("__parent__",)

def __init__(
self,
parent: Optional["Builder"] = None,
) -> None:
self.__parent__ = parent

def __str__(self):
return str(self.parse())
7 changes: 1 addition & 6 deletions src/bloqade/builder/parse/trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ def n_atoms(self: "Builder"):
else:
return register.n_atoms

def __repr__(self: "Builder"):
from .builder import Parser

analog_circ, metas = Parser().parse(self)

return repr(analog_circ) + "\n" + repr(metas)

class Show:
def show(self, batch_id: int = 0):
display_builder(self, batch_id)
4 changes: 2 additions & 2 deletions src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def visit_pulse(self, ast: pulse.Pulse) -> pulse.Pulse:
)

def visit_append_pulse(self, ast: pulse.Append) -> pulse.Append:
return pulse.Append(list(map(self.visit, ast.value)))
return pulse.Append(list(map(self.visit, ast.pulses)))

def visit_slice_pulse(self, ast: pulse.Slice) -> pulse.Slice:
return pulse.Slice(self.visit(ast.pulse), self.scalar_visitor(ast.interval))
Expand All @@ -268,7 +268,7 @@ def visit_named_pulse(self, ast: pulse.NamedPulse) -> Any:

def visit_field(self, ast: field.Field) -> field.Field:
return field.Field(
{self.visit(sm): self.visit(wf) for sm, wf in ast.value.items()}
{self.visit(sm): self.visit(wf) for sm, wf in ast.drives.items()}
)

def visit_uniform_modulation(self, ast: field.UniformModulation) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/codegen/common/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def visit_assigned_run_time_vector(
def visit_field(self, ast: field.Field) -> Any:
return {
"field": {
"value": [(self.visit(k), self.visit(v)) for k, v in ast.value.items()]
"value": [(self.visit(k), self.visit(v)) for k, v in ast.drives.items()]
}
}

Expand Down
34 changes: 17 additions & 17 deletions src/bloqade/codegen/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def visit_detuning(self, ast: Optional[Field]):

terms = []

if len(ast.value) <= self.n_atoms:
for sm, wf in ast.value.items():
if len(ast.drives) <= self.n_atoms:
for sm, wf in ast.drives.items():
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -181,7 +181,7 @@ def visit_detuning(self, ast: Optional[Field]):
)
)
else:
target_atom_dict = {sm: self.visit(sm) for sm in ast.value.keys()}
target_atom_dict = {sm: self.visit(sm) for sm in ast.drives.keys()}

for atom in range(self.n_atoms):
if not any(atom in value for value in target_atom_dict.values()):
Expand All @@ -190,7 +190,7 @@ def visit_detuning(self, ast: Optional[Field]):
wf = sum(
(
target_atom_dict[sm][atom] * wf
for sm, wf in ast.value.items()
for sm, wf in ast.drives.items()
if atom in target_atom_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand All @@ -217,8 +217,8 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
if amplitude is None:
return terms

if phase is None and len(amplitude.value) <= self.n_atoms:
for sm, wf in amplitude.value.items():
if phase is None and len(amplitude.drives) <= self.n_atoms:
for sm, wf in amplitude.drives.items():
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -243,7 +243,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
)
elif phase is None: # fully local real rabi fields
amplitude_target_atoms_dict = {
sm: self.visit(sm) for sm in amplitude.value.keys()
sm: self.visit(sm) for sm in amplitude.drives.keys()
}
for atom in range(self.n_atoms):
if not any(
Expand All @@ -254,7 +254,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
for sm, wf in amplitude.drives.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand All @@ -274,16 +274,16 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
)
)
elif (
len(phase.value) == 1
and UniformModulation() in phase.value
and len(amplitude.value) <= self.n_atoms
len(phase.drives) == 1
and UniformModulation() in phase.drives
and len(amplitude.drives) <= self.n_atoms
):
(phase_waveform,) = phase.value.values()
(phase_waveform,) = phase.drives.values()
rabi_phase = self.waveform_compiler.emit(phase_waveform)
self.duration = max(
float(phase_waveform.duration(**self.assignments)), self.duration
)
for sm, wf in amplitude.value.items():
for sm, wf in amplitude.drives.items():
self.duration = max(
float(wf.duration(**self.assignments)), self.duration
)
Expand All @@ -308,9 +308,9 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
)
)
else:
phase_target_atoms_dict = {sm: self.visit(sm) for sm in phase.value.keys()}
phase_target_atoms_dict = {sm: self.visit(sm) for sm in phase.drives.keys()}
amplitude_target_atoms_dict = {
sm: self.visit(sm) for sm in amplitude.value.keys()
sm: self.visit(sm) for sm in amplitude.drives.keys()
}

terms = []
Expand All @@ -323,7 +323,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
phase_wf = sum(
(
phase_target_atoms_dict[sm][atom] * wf
for sm, wf in phase.value.items()
for sm, wf in phase.drives.items()
if atom in phase_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand All @@ -332,7 +332,7 @@ def visit_rabi(self, amplitude: Optional[Field], phase: Optional[Field]):
amplitude_wf = sum(
(
amplitude_target_atoms_dict[sm][atom] * wf
for sm, wf in amplitude.value.items()
for sm, wf in amplitude.drives.items()
if atom in amplitude_target_atoms_dict[sm]
),
start=waveform.Constant(0.0, 0.0),
Expand Down
24 changes: 12 additions & 12 deletions src/bloqade/codegen/hardware/quera.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ def visit_assigned_run_time_vector(self, ast: field.AssignedRunTimeVector) -> An
self.post_visit_spatial_modulation(lattice_site_coefficients)

def visit_detuning(self, ast: field.Field) -> Any:
if len(ast.value) == 1 and field.Uniform in ast.value:
if len(ast.drives) == 1 and field.Uniform in ast.drives:
times, values = PiecewiseLinearCodeGen(self.assignments).visit(
ast.value[field.Uniform]
ast.drives[field.Uniform]
)

times = QuEraCodeGen.convert_time_to_SI_units(times)
Expand All @@ -429,8 +429,8 @@ def visit_detuning(self, ast: field.Field) -> Any:
self.detuning = task_spec.Detuning(
global_=task_spec.GlobalField(times=times, values=values)
)
elif len(ast.value) == 1:
((spatial_modulation, waveform),) = ast.value.items()
elif len(ast.drives) == 1:
((spatial_modulation, waveform),) = ast.drives.items()

times, values = PiecewiseLinearCodeGen(self.assignments).visit(waveform)

Expand All @@ -446,18 +446,18 @@ def visit_detuning(self, ast: field.Field) -> Any:
lattice_site_coefficients=self.lattice_site_coefficients,
),
)
elif len(ast.value) == 2 and field.Uniform in ast.value:
elif len(ast.drives) == 2 and field.Uniform in ast.drives:
# will only be two keys
for k in ast.value.keys():
for k in ast.drives.keys():
if k == field.Uniform:
global_times, global_values = PiecewiseLinearCodeGen(
self.assignments
).visit(ast.value[field.Uniform])
).visit(ast.drives[field.Uniform])
else: # can be field.RunTimeVector or field.ScaledLocations
spatial_modulation = k
local_times, local_values = PiecewiseLinearCodeGen(
self.assignments
).visit(ast.value[k])
).visit(ast.drives[k])

self.visit(spatial_modulation) # just visit the non-uniform locations

Expand All @@ -483,9 +483,9 @@ def visit_detuning(self, ast: field.Field) -> Any:
)

def visit_rabi_amplitude(self, ast: field.Field) -> Any:
if len(ast.value) == 1 and field.Uniform in ast.value:
if len(ast.drives) == 1 and field.Uniform in ast.drives:
times, values = PiecewiseLinearCodeGen(self.assignments).visit(
ast.value[field.Uniform]
ast.drives[field.Uniform]
)

times = QuEraCodeGen.convert_time_to_SI_units(times)
Expand All @@ -502,9 +502,9 @@ def visit_rabi_amplitude(self, ast: field.Field) -> Any:
)

def visit_rabi_phase(self, ast: field.Field) -> Any:
if len(ast.value) == 1 and field.Uniform in ast.value: # has to be global
if len(ast.drives) == 1 and field.Uniform in ast.drives: # has to be global
times, values = PiecewiseConstantCodeGen(self.assignments).visit(
ast.value[field.Uniform]
ast.drives[field.Uniform]
)

times = QuEraCodeGen.convert_time_to_SI_units(times)
Expand Down
58 changes: 25 additions & 33 deletions src/bloqade/ir/analog_circuit.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
# from numbers import Real
from typing import TYPE_CHECKING, Union
from bloqade.visualization import display_ir

if TYPE_CHECKING:
from bloqade.ir.location.base import AtomArrangement, ParallelRegister
from bloqade.ir import Sequence
from bloqade.ir.control.sequence import SequenceExpr
from bloqade.ir.location.base import AtomArrangement, ParallelRegister
from bloqade.ir.tree_print import Printer
from beartype.typing import Union
from pydantic.dataclasses import dataclass


# NOTE: this is just a dummy type bundle geometry and sequence
# information together and forward them to backends.
@dataclass(frozen=True)
class AnalogCircuit:
"""AnalogCircuit is a dummy type that bundle register and sequence together."""

def __init__(
self,
register: Union["AtomArrangement", "ParallelRegister"],
sequence: "Sequence",
):
self._sequence = sequence
self._register = register
atom_arrangement: Union[ParallelRegister, AtomArrangement]
sequence: SequenceExpr

@property
def register(self):
Expand All @@ -35,19 +31,7 @@ def register(self):
Otherwise it will be a
[`AtomArrangement`][bloqade.ir.location.base.AtomArrangement].
"""
return self._register

@property
def sequence(self):
"""Get the sequence of the program.


Returns:
Sequence: the sequence of the program.
See also [`Sequence`][bloqade.ir.control.sequence.Sequence].

"""
return self._sequence
return self.atom_arrangement

def __eq__(self, other):
if isinstance(other, AnalogCircuit):
Expand All @@ -57,22 +41,30 @@ def __eq__(self, other):

return False

def __repr__(self):
# TODO: add repr for static_params, batch_params and order
def __str__(self):
out = ""
if self._register is not None:
out += self._register.__repr__()
if self.register is not None:
out += self.register.__str__()

out += "\n"

if self._sequence is not None:
out += self._sequence.__repr__()
if self.sequence is not None:
out += self.sequence.__str__()

return out

def print_node(self):
return "AnalogCircuit"

def children(self):
return {"register": self.atom_arrangement, "sequence": self.sequence}

def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

def figure(self, **assignments):
fig_reg = self._register.figure(**assignments)
fig_seq = self._sequence.figure(**assignments)
fig_reg = self.register.figure(**assignments)
fig_seq = self.sequence.figure(**assignments)
return fig_seq, fig_reg

def show(self, **assignments):
Expand Down
6 changes: 3 additions & 3 deletions src/bloqade/ir/analysis/assignment_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def visit_named_sequence(self, ast: sequence.NamedSequence):
self.visit(ast.sequence)

def visit_append_sequence(self, ast: sequence.Append):
list(map(self.visit, ast.value))
list(map(self.visit, ast.sequences))

def visit_slice_sequence(self, ast: sequence.Slice):
self.visit(ast.sequence)
Expand All @@ -99,13 +99,13 @@ def visit_named_pulse(self, ast: pulse.NamedPulse) -> Any:
self.visit(ast.pulse)

def visit_append_pulse(self, ast: pulse.Append) -> Any:
list(map(self.visit, ast.value))
list(map(self.visit, ast.pulses))

def visit_slice_pulse(self, ast: pulse.Slice) -> Any:
self.visit(ast.pulse)

def visit_field(self, ast: field.Field):
list(map(self.visit, ast.value.values()))
list(map(self.visit, ast.drives.values()))

def visit_waveform(self, ast: waveform.Waveform):
self.assignments.update(self.waveform_visitor.emit(ast))
Expand Down
8 changes: 4 additions & 4 deletions src/bloqade/ir/analysis/scan_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ def visit_assigned_run_time_vector(self, ast: field.AssignedRunTimeVector) -> An
self.vector_vars.add(ast.name)

def visit_field(self, ast: field.Field) -> Any:
list(map(self.visit, ast.value.keys()))
for value in ast.value.values():
list(map(self.visit, ast.drives.keys()))
for value in ast.drives.values():
self.visit(value)

def visit_pulse(self, ast: pulse.Pulse) -> Any:
list(map(self.visit, ast.fields.values()))

def visit_append_pulse(self, ast: pulse.Append) -> Any:
list(map(self.visit, ast.value))
list(map(self.visit, ast.pulses))

def visit_slice_pulse(self, ast: pulse.Slice) -> Any:
self.scalar_vars = self.scalar_vars.union(
Expand All @@ -223,7 +223,7 @@ def visit_sequence(self, ast: sequence.Sequence) -> Any:
list(map(self.visit, ast.pulses.values()))

def visit_append_sequence(self, ast: sequence.Append) -> Any:
list(map(self.visit, ast.value))
list(map(self.visit, ast.sequences))

def visit_slice_sequence(self, ast: sequence.Slice) -> Any:
self.scalar_vars = self.scalar_vars.union(
Expand Down
Loading