Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Field canonicalization #661

Merged
merged 13 commits into from
Oct 4, 2023
2 changes: 1 addition & 1 deletion src/bloqade/builder/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def sample(
return Sample(dt, interpolation, self)

def __bloqade_ir__(self):
return ir.PythonFn(self._fn, self._duration)
return ir.PythonFn.create(self._fn, self._duration)


# NOTE: no double-slice or double-record
Expand Down
18 changes: 15 additions & 3 deletions src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,21 @@
return waveform.Poly(checkpoints, duration)

def visit_python_fn(self, ast: waveform.PythonFn) -> Any:
new_ast = waveform.PythonFn(ast.fn, self.scalar_visitor.emit(ast.duration))
new_ast.parameters = list(map(self.scalar_visitor.emit, ast.parameters))
return new_ast
duration = self.scalar_visitor.emit(ast.duration)
new_parameters = []
default_param_values = dict(ast.default_param_values)

new_parameters = list(map(self.scalar_visitor.emit, ast.parameters))

# remove default parameters that are overwritten by the assignment
for param in new_parameters:
if (
isinstance(param, scalar.AssignedVariable)
and param.name in default_param_values
):
default_param_values.pop(param.name)

Check warning on line 106 in src/bloqade/codegen/common/assign_variables.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/codegen/common/assign_variables.py#L106

Added line #L106 was not covered by tests

return waveform.PythonFn(ast.fn, duration, new_parameters, default_param_values)

def visit_add(self, ast: waveform.Add) -> Any:
return waveform.Add(self.visit(ast.left), self.visit(ast.right))
Expand Down
52 changes: 47 additions & 5 deletions src/bloqade/ir/control/field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cached_property
from ..scalar import Scalar, cast
from ..tree_print import Printer
from .waveform import Waveform
Expand Down Expand Up @@ -107,8 +108,12 @@ class AssignedRunTimeVector(SpatialModulation):
name: str
value: List[Decimal]

@cached_property
def _hash_value(self):
return hash(self.name) ^ hash(tuple(self.value)) ^ hash(self.__class__)

def __hash__(self) -> int:
return hash(self.name) ^ hash(self.__class__) ^ hash(tuple(self.value))
return self._hash_value

def print_node(self):
return "AssgiendRunTimeVector"
Expand Down Expand Up @@ -143,9 +148,13 @@ def __init__(self, pairs):
value[k] = cast(v)
self.value = value

def __hash__(self) -> int:
@cached_property
def _hash_value(self) -> int:
return hash(frozenset(self.value.items())) ^ hash(self.__class__)

def __hash__(self) -> int:
return self._hash_value

def _get_data(self, **assignments):
names = []
scls = []
Expand Down Expand Up @@ -197,7 +206,7 @@ def children(self):
return {"modulation": self.modulation, "waveform": self.waveform}


@dataclass
@dataclass(frozen=True)
class Field(FieldExpr):
"""Field node in the IR. Which contains collection(s) of
[`Waveform`][bloqade.ir.control.waveform.Waveform]
Expand All @@ -209,9 +218,42 @@ class Field(FieldExpr):

drives: Dict[SpatialModulation, Waveform]

def __hash__(self) -> int:
@cached_property
def _hash_value(self):
return hash(frozenset(self.drives.items())) ^ hash(self.__class__)

def __hash__(self) -> int:
return self._hash_value

def canonicalize(self) -> "Field":
"""
Canonicalize the Field by merging `ScaledLocation` nodes with the same waveform.
"""
reversed_dirves = {}

for sm, wf in self.drives.items():
reversed_dirves[wf] = reversed_dirves.get(wf, []) + [sm]

drives = {}

for wf, sms in reversed_dirves.items():
new_sm = [sm for sm in sms if not isinstance(sm, ScaledLocations)]
scaled_locations_sm = [sm for sm in sms if isinstance(sm, ScaledLocations)]

new_mask = {}

for ele in scaled_locations_sm:
for loc, scl in ele.value.items():
new_mask[loc] = new_mask.get(loc, 0) + cast(scl)

if new_mask:
new_sm += [ScaledLocations(new_mask)]

for sm in new_sm:
drives[sm] = wf

return Field(drives)

def add(self, other):
if not isinstance(other, Field):
raise ValueError(f"Cannot add Field and {other.__class__}")
Expand All @@ -226,7 +268,7 @@ def add(self, other):
else:
out.drives[spatial_modulation] = waveform

return out
return out.canonicalize()

def print_node(self):
return "Field"
Expand Down
Loading
Loading