Skip to content

Commit

Permalink
Field canonicalization (#661)
Browse files Browse the repository at this point in the history
* adding hashing for waveforms

* fixing hashing of waveforms.

* adding canonicalization to field with waveform hashing.

* only merging `ScaledLocation` nodes in field canonicalizaiton.

* adding back hash for field object.

* add tests for hash.

* creating new API for `PythonFn` to stop some of the 'hacking' of attributes.

* adding `cached_property` for  `_hash_value`

* simplifying assign.

* adding docstring.
  • Loading branch information
weinbe58 authored Oct 4, 2023
1 parent cdd2bfe commit 3659e3d
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 146 deletions.
2 changes: 1 addition & 1 deletion src/bloqade/builder/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def sample(
return Sample(dt, interpolation, self)

def __bloqade_ir__(self):
return ir.PythonFn(self._fn, self._duration)
return ir.PythonFn.create(self._fn, self._duration)


# NOTE: no double-slice or double-record
Expand Down
18 changes: 15 additions & 3 deletions src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,21 @@ def visit_poly(self, ast: waveform.Poly) -> Any:
return waveform.Poly(checkpoints, duration)

def visit_python_fn(self, ast: waveform.PythonFn) -> Any:
new_ast = waveform.PythonFn(ast.fn, self.scalar_visitor.emit(ast.duration))
new_ast.parameters = list(map(self.scalar_visitor.emit, ast.parameters))
return new_ast
duration = self.scalar_visitor.emit(ast.duration)
new_parameters = []
default_param_values = dict(ast.default_param_values)

new_parameters = list(map(self.scalar_visitor.emit, ast.parameters))

# remove default parameters that are overwritten by the assignment
for param in new_parameters:
if (
isinstance(param, scalar.AssignedVariable)
and param.name in default_param_values
):
default_param_values.pop(param.name)

return waveform.PythonFn(ast.fn, duration, new_parameters, default_param_values)

def visit_add(self, ast: waveform.Add) -> Any:
return waveform.Add(self.visit(ast.left), self.visit(ast.right))
Expand Down
52 changes: 47 additions & 5 deletions src/bloqade/ir/control/field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cached_property
from ..scalar import Scalar, cast
from ..tree_print import Printer
from .waveform import Waveform
Expand Down Expand Up @@ -107,8 +108,12 @@ class AssignedRunTimeVector(SpatialModulation):
name: str
value: List[Decimal]

@cached_property
def _hash_value(self):
return hash(self.name) ^ hash(tuple(self.value)) ^ hash(self.__class__)

def __hash__(self) -> int:
return hash(self.name) ^ hash(self.__class__) ^ hash(tuple(self.value))
return self._hash_value

def print_node(self):
return "AssgiendRunTimeVector"
Expand Down Expand Up @@ -143,9 +148,13 @@ def __init__(self, pairs):
value[k] = cast(v)
self.value = value

def __hash__(self) -> int:
@cached_property
def _hash_value(self) -> int:
return hash(frozenset(self.value.items())) ^ hash(self.__class__)

def __hash__(self) -> int:
return self._hash_value

def _get_data(self, **assignments):
names = []
scls = []
Expand Down Expand Up @@ -197,7 +206,7 @@ def children(self):
return {"modulation": self.modulation, "waveform": self.waveform}


@dataclass
@dataclass(frozen=True)
class Field(FieldExpr):
"""Field node in the IR. Which contains collection(s) of
[`Waveform`][bloqade.ir.control.waveform.Waveform]
Expand All @@ -209,9 +218,42 @@ class Field(FieldExpr):

drives: Dict[SpatialModulation, Waveform]

def __hash__(self) -> int:
@cached_property
def _hash_value(self):
return hash(frozenset(self.drives.items())) ^ hash(self.__class__)

def __hash__(self) -> int:
return self._hash_value

def canonicalize(self) -> "Field":
"""
Canonicalize the Field by merging `ScaledLocation` nodes with the same waveform.
"""
reversed_dirves = {}

for sm, wf in self.drives.items():
reversed_dirves[wf] = reversed_dirves.get(wf, []) + [sm]

drives = {}

for wf, sms in reversed_dirves.items():
new_sm = [sm for sm in sms if not isinstance(sm, ScaledLocations)]
scaled_locations_sm = [sm for sm in sms if isinstance(sm, ScaledLocations)]

new_mask = {}

for ele in scaled_locations_sm:
for loc, scl in ele.value.items():
new_mask[loc] = new_mask.get(loc, 0) + cast(scl)

if new_mask:
new_sm += [ScaledLocations(new_mask)]

for sm in new_sm:
drives[sm] = wf

return Field(drives)

def add(self, other):
if not isinstance(other, Field):
raise ValueError(f"Cannot add Field and {other.__class__}")
Expand All @@ -226,7 +268,7 @@ def add(self, other):
else:
out.drives[spatial_modulation] = waveform

return out
return out.canonicalize()

def print_node(self):
return "Field"
Expand Down
Loading

0 comments on commit 3659e3d

Please sign in to comment.