Skip to content

Commit

Permalink
creating new API for PythonFn to stop some of the 'hacking' of attr…
Browse files Browse the repository at this point in the history
…ibutes.
  • Loading branch information
weinbe58 committed Oct 2, 2023
1 parent c7d1820 commit a468042
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/bloqade/builder/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def sample(
return Sample(dt, interpolation, self)

def __bloqade_ir__(self):
return ir.PythonFn(self._fn, self._duration)
return ir.PythonFn.create(self._fn, self._duration)


# NOTE: no double-slice or double-record
Expand Down
8 changes: 2 additions & 6 deletions src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def visit_poly(self, ast: waveform.Poly) -> Any:
return waveform.Poly(checkpoints, duration)

def visit_python_fn(self, ast: waveform.PythonFn) -> Any:
new_fn = waveform.PythonFn(ast.fn, self.scalar_visitor.emit(ast.duration))

duration = self.scalar_visitor.emit(ast.duration)
new_parameters = []
default_param_values = dict(ast.default_param_values)

Expand All @@ -107,10 +106,7 @@ def visit_python_fn(self, ast: waveform.PythonFn) -> Any:
else:
new_parameters.append(param)

object.__setattr__(new_fn, "parameters", new_parameters)
object.__setattr__(new_fn, "default_param_values", default_param_values)

return new_fn
return waveform.PythonFn(ast.fn, duration, new_parameters, default_param_values)

def visit_add(self, ast: waveform.Add) -> Any:
return waveform.Add(self.visit(ast.left), self.visit(ast.right))
Expand Down
17 changes: 9 additions & 8 deletions src/bloqade/ir/control/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def to_waveform(duration: ScalarType) -> Callable[[Callable], "PythonFn"]:
# turn python function into a waveform instruction."""

def waveform_wrapper(fn: Callable) -> "PythonFn":
return PythonFn(fn, duration)
return PythonFn.create(fn, duration)

return waveform_wrapper

Expand Down Expand Up @@ -407,7 +407,7 @@ def children(self):
return annotated_coeffs


@dataclass(init=False, frozen=True)
@dataclass(frozen=True)
class PythonFn(Instruction):
"""
Expand All @@ -417,11 +417,12 @@ class PythonFn(Instruction):
"""

fn: Callable # [[float, ...], float] # f(t) -> value
parameters: List[Union[Variable, AssignedVariable]] # come from ast inspect
duration: Scalar
parameters: List[Union[Variable, AssignedVariable]] # come from ast inspect
default_param_values: Dict[str, Decimal]

def __init__(self, fn: Callable, duration: Any):
@staticmethod
def create(fn: Callable, duration: ScalarType) -> "PythonFn":
signature = inspect.getfullargspec(fn)

if signature.varargs is not None:
Expand All @@ -447,10 +448,10 @@ def __init__(self, fn: Callable, duration: Any):
variables += signature.args[1:]
variables += signature.kwonlyargs

object.__setattr__(self, "fn", fn)
object.__setattr__(self, "parameters", list(map(var, variables)))
object.__setattr__(self, "duration", cast(duration))
object.__setattr__(self, "default_param_values", default_param_values)
parameters = list(map(var, variables))
duration = cast(duration)

return PythonFn(fn, duration, parameters, default_param_values)

@cached_property
def _hash_value(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hardware_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def my_cos(time):

assert my_cos(1) == np.cos(1)

wv = PythonFn(my_cos, duration=1.0)
wv = PythonFn.create(my_cos, duration=1.0)
dt = cast(0.1)

wf = Sample(wv, Interpolation.Constant, dt)
Expand Down Expand Up @@ -584,7 +584,7 @@ def my_cos(time):

assert my_cos(1) == np.cos(1)

wv = PythonFn(my_cos, duration=1.0)
wv = PythonFn.create(my_cos, duration=1.0)
dt = cast(0.1)

wf = Sample(wv, Interpolation.Linear, dt)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ def my_func3(time, omega, **phi):
assert my_func3(3, 2) == 3

with pytest.raises(ValueError):
PythonFn(my_func2, duration=1.0)
PythonFn.create(my_func2, duration=1.0)

with pytest.raises(ValueError):
PythonFn(my_func3, duration=1.0)
PythonFn.create(my_func3, duration=1.0)

wf = PythonFn(my_func, duration=1.0)
wf = PythonFn.create(my_func, duration=1.0)
awf = annot_my_func

assert wf.eval_decimal(Decimal("0.56"), omega=1, amplitude=4) == awf.eval_decimal(
Expand Down Expand Up @@ -462,7 +462,7 @@ def my_cos(time):

assert my_cos(1) == np.cos(1)

wv = PythonFn(my_cos, duration=1.0)
wv = PythonFn.create(my_cos, duration=1.0)
dt = cast(0.1)

wf = Sample(wv, Interpolation.Constant, dt)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_waveform_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
wv_scale = Scale(cast(4), wv_linear)
wv_add = Add(wv_linear, wv_constant)
wv_record = Record(wv_linear, cast("tst"))
wv_python = PythonFn(lambda x: x**2, duration=0.5)
wv_python = PythonFn.create(lambda x: x**2, duration=0.5)
wv_sample = Sample(wv_linear, Interpolation("linear"), dt=cast(0.5))
wv_poly = Poly(coeffs=[0.1, 0.2], duration=0.5)
wv_smooth = Smooth(radius=cast(0.5), kernel=GaussianKernel, waveform=wv_linear)
Expand Down

0 comments on commit a468042

Please sign in to comment.