From a4680426bd6b5c06a668387a055c5b3af725bf20 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 2 Oct 2023 13:08:32 -0400 Subject: [PATCH] creating new API for `PythonFn` to stop some of the 'hacking' of attributes. --- src/bloqade/builder/waveform.py | 2 +- src/bloqade/codegen/common/assign_variables.py | 8 ++------ src/bloqade/ir/control/waveform.py | 17 +++++++++-------- tests/test_hardware_codegen.py | 4 ++-- tests/test_waveform.py | 8 ++++---- tests/test_waveform_visitor.py | 2 +- 6 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/bloqade/builder/waveform.py b/src/bloqade/builder/waveform.py index b340b40bc..8259eeca6 100644 --- a/src/bloqade/builder/waveform.py +++ b/src/bloqade/builder/waveform.py @@ -873,7 +873,7 @@ def sample( return Sample(dt, interpolation, self) def __bloqade_ir__(self): - return ir.PythonFn(self._fn, self._duration) + return ir.PythonFn.create(self._fn, self._duration) # NOTE: no double-slice or double-record diff --git a/src/bloqade/codegen/common/assign_variables.py b/src/bloqade/codegen/common/assign_variables.py index 601378bcf..6cd2298a8 100644 --- a/src/bloqade/codegen/common/assign_variables.py +++ b/src/bloqade/codegen/common/assign_variables.py @@ -91,8 +91,7 @@ def visit_poly(self, ast: waveform.Poly) -> Any: return waveform.Poly(checkpoints, duration) def visit_python_fn(self, ast: waveform.PythonFn) -> Any: - new_fn = waveform.PythonFn(ast.fn, self.scalar_visitor.emit(ast.duration)) - + duration = self.scalar_visitor.emit(ast.duration) new_parameters = [] default_param_values = dict(ast.default_param_values) @@ -107,10 +106,7 @@ def visit_python_fn(self, ast: waveform.PythonFn) -> Any: else: new_parameters.append(param) - object.__setattr__(new_fn, "parameters", new_parameters) - object.__setattr__(new_fn, "default_param_values", default_param_values) - - return new_fn + return waveform.PythonFn(ast.fn, duration, new_parameters, default_param_values) def visit_add(self, ast: waveform.Add) -> Any: return waveform.Add(self.visit(ast.left), self.visit(ast.right)) diff --git a/src/bloqade/ir/control/waveform.py b/src/bloqade/ir/control/waveform.py index 840ba76bd..a14b66f1e 100644 --- a/src/bloqade/ir/control/waveform.py +++ b/src/bloqade/ir/control/waveform.py @@ -30,7 +30,7 @@ def to_waveform(duration: ScalarType) -> Callable[[Callable], "PythonFn"]: # turn python function into a waveform instruction.""" def waveform_wrapper(fn: Callable) -> "PythonFn": - return PythonFn(fn, duration) + return PythonFn.create(fn, duration) return waveform_wrapper @@ -407,7 +407,7 @@ def children(self): return annotated_coeffs -@dataclass(init=False, frozen=True) +@dataclass(frozen=True) class PythonFn(Instruction): """ @@ -417,11 +417,12 @@ class PythonFn(Instruction): """ fn: Callable # [[float, ...], float] # f(t) -> value - parameters: List[Union[Variable, AssignedVariable]] # come from ast inspect duration: Scalar + parameters: List[Union[Variable, AssignedVariable]] # come from ast inspect default_param_values: Dict[str, Decimal] - def __init__(self, fn: Callable, duration: Any): + @staticmethod + def create(fn: Callable, duration: ScalarType) -> "PythonFn": signature = inspect.getfullargspec(fn) if signature.varargs is not None: @@ -447,10 +448,10 @@ def __init__(self, fn: Callable, duration: Any): variables += signature.args[1:] variables += signature.kwonlyargs - object.__setattr__(self, "fn", fn) - object.__setattr__(self, "parameters", list(map(var, variables))) - object.__setattr__(self, "duration", cast(duration)) - object.__setattr__(self, "default_param_values", default_param_values) + parameters = list(map(var, variables)) + duration = cast(duration) + + return PythonFn(fn, duration, parameters, default_param_values) @cached_property def _hash_value(self) -> int: diff --git a/tests/test_hardware_codegen.py b/tests/test_hardware_codegen.py index a45084a2e..1525beef1 100644 --- a/tests/test_hardware_codegen.py +++ b/tests/test_hardware_codegen.py @@ -180,7 +180,7 @@ def my_cos(time): assert my_cos(1) == np.cos(1) - wv = PythonFn(my_cos, duration=1.0) + wv = PythonFn.create(my_cos, duration=1.0) dt = cast(0.1) wf = Sample(wv, Interpolation.Constant, dt) @@ -584,7 +584,7 @@ def my_cos(time): assert my_cos(1) == np.cos(1) - wv = PythonFn(my_cos, duration=1.0) + wv = PythonFn.create(my_cos, duration=1.0) dt = cast(0.1) wf = Sample(wv, Interpolation.Linear, dt) diff --git a/tests/test_waveform.py b/tests/test_waveform.py index 9d733d06c..7507ebe46 100644 --- a/tests/test_waveform.py +++ b/tests/test_waveform.py @@ -102,12 +102,12 @@ def my_func3(time, omega, **phi): assert my_func3(3, 2) == 3 with pytest.raises(ValueError): - PythonFn(my_func2, duration=1.0) + PythonFn.create(my_func2, duration=1.0) with pytest.raises(ValueError): - PythonFn(my_func3, duration=1.0) + PythonFn.create(my_func3, duration=1.0) - wf = PythonFn(my_func, duration=1.0) + wf = PythonFn.create(my_func, duration=1.0) awf = annot_my_func assert wf.eval_decimal(Decimal("0.56"), omega=1, amplitude=4) == awf.eval_decimal( @@ -462,7 +462,7 @@ def my_cos(time): assert my_cos(1) == np.cos(1) - wv = PythonFn(my_cos, duration=1.0) + wv = PythonFn.create(my_cos, duration=1.0) dt = cast(0.1) wf = Sample(wv, Interpolation.Constant, dt) diff --git a/tests/test_waveform_visitor.py b/tests/test_waveform_visitor.py index a421e1054..8aafda6b1 100644 --- a/tests/test_waveform_visitor.py +++ b/tests/test_waveform_visitor.py @@ -32,7 +32,7 @@ wv_scale = Scale(cast(4), wv_linear) wv_add = Add(wv_linear, wv_constant) wv_record = Record(wv_linear, cast("tst")) -wv_python = PythonFn(lambda x: x**2, duration=0.5) +wv_python = PythonFn.create(lambda x: x**2, duration=0.5) wv_sample = Sample(wv_linear, Interpolation("linear"), dt=cast(0.5)) wv_poly = Poly(coeffs=[0.1, 0.2], duration=0.5) wv_smooth = Smooth(radius=cast(0.5), kernel=GaussianKernel, waveform=wv_linear)