diff --git a/src/bloqade/builder/waveform.py b/src/bloqade/builder/waveform.py index fa4d7d092..781aee852 100644 --- a/src/bloqade/builder/waveform.py +++ b/src/bloqade/builder/waveform.py @@ -1019,7 +1019,7 @@ def sample( return Sample(dt, interpolation, self) def __bloqade_ir__(self): - return ir.PythonFn(self._fn, self._duration) + return ir.PythonFn.create(self._fn, self._duration) # NOTE: no double-slice or double-record diff --git a/src/bloqade/codegen/common/assign_variables.py b/src/bloqade/codegen/common/assign_variables.py index 0308493bd..83c1b5a16 100644 --- a/src/bloqade/codegen/common/assign_variables.py +++ b/src/bloqade/codegen/common/assign_variables.py @@ -91,9 +91,21 @@ def visit_poly(self, ast: waveform.Poly) -> Any: return waveform.Poly(checkpoints, duration) def visit_python_fn(self, ast: waveform.PythonFn) -> Any: - new_ast = waveform.PythonFn(ast.fn, self.scalar_visitor.emit(ast.duration)) - new_ast.parameters = list(map(self.scalar_visitor.emit, ast.parameters)) - return new_ast + duration = self.scalar_visitor.emit(ast.duration) + new_parameters = [] + default_param_values = dict(ast.default_param_values) + + new_parameters = list(map(self.scalar_visitor.emit, ast.parameters)) + + # remove default parameters that are overwritten by the assignment + for param in new_parameters: + if ( + isinstance(param, scalar.AssignedVariable) + and param.name in default_param_values + ): + default_param_values.pop(param.name) + + return waveform.PythonFn(ast.fn, duration, new_parameters, default_param_values) def visit_add(self, ast: waveform.Add) -> Any: return waveform.Add(self.visit(ast.left), self.visit(ast.right)) diff --git a/src/bloqade/ir/control/field.py b/src/bloqade/ir/control/field.py index 89f8db324..9a357c396 100644 --- a/src/bloqade/ir/control/field.py +++ b/src/bloqade/ir/control/field.py @@ -1,3 +1,4 @@ +from functools import cached_property from ..scalar import Scalar, cast from ..tree_print import Printer from .waveform import Waveform @@ -107,8 +108,12 @@ class AssignedRunTimeVector(SpatialModulation): name: str value: List[Decimal] + @cached_property + def _hash_value(self): + return hash(self.name) ^ hash(tuple(self.value)) ^ hash(self.__class__) + def __hash__(self) -> int: - return hash(self.name) ^ hash(self.__class__) ^ hash(tuple(self.value)) + return self._hash_value def print_node(self): return "AssgiendRunTimeVector" @@ -143,9 +148,13 @@ def __init__(self, pairs): value[k] = cast(v) self.value = value - def __hash__(self) -> int: + @cached_property + def _hash_value(self) -> int: return hash(frozenset(self.value.items())) ^ hash(self.__class__) + def __hash__(self) -> int: + return self._hash_value + def _get_data(self, **assignments): names = [] scls = [] @@ -197,7 +206,7 @@ def children(self): return {"modulation": self.modulation, "waveform": self.waveform} -@dataclass +@dataclass(frozen=True) class Field(FieldExpr): """Field node in the IR. Which contains collection(s) of [`Waveform`][bloqade.ir.control.waveform.Waveform] @@ -209,9 +218,42 @@ class Field(FieldExpr): drives: Dict[SpatialModulation, Waveform] - def __hash__(self) -> int: + @cached_property + def _hash_value(self): return hash(frozenset(self.drives.items())) ^ hash(self.__class__) + def __hash__(self) -> int: + return self._hash_value + + def canonicalize(self) -> "Field": + """ + Canonicalize the Field by merging `ScaledLocation` nodes with the same waveform. + """ + reversed_dirves = {} + + for sm, wf in self.drives.items(): + reversed_dirves[wf] = reversed_dirves.get(wf, []) + [sm] + + drives = {} + + for wf, sms in reversed_dirves.items(): + new_sm = [sm for sm in sms if not isinstance(sm, ScaledLocations)] + scaled_locations_sm = [sm for sm in sms if isinstance(sm, ScaledLocations)] + + new_mask = {} + + for ele in scaled_locations_sm: + for loc, scl in ele.value.items(): + new_mask[loc] = new_mask.get(loc, 0) + cast(scl) + + if new_mask: + new_sm += [ScaledLocations(new_mask)] + + for sm in new_sm: + drives[sm] = wf + + return Field(drives) + def add(self, other): if not isinstance(other, Field): raise ValueError(f"Cannot add Field and {other.__class__}") @@ -226,7 +268,7 @@ def add(self, other): else: out.drives[spatial_modulation] = waveform - return out + return out.canonicalize() def print_node(self): return "Field" diff --git a/src/bloqade/ir/control/waveform.py b/src/bloqade/ir/control/waveform.py index ca826d5dd..e92c4898a 100644 --- a/src/bloqade/ir/control/waveform.py +++ b/src/bloqade/ir/control/waveform.py @@ -11,7 +11,6 @@ ) from bisect import bisect_left -from dataclasses import InitVar from decimal import Decimal from pydantic.dataclasses import dataclass from beartype.typing import Any, Tuple, Union, List, Callable, Dict @@ -23,6 +22,7 @@ import scipy.integrate as integrate from bloqade.visualization import get_ir_figure from bloqade.visualization import display_ir +from functools import cached_property @beartype @@ -30,7 +30,7 @@ def to_waveform(duration: ScalarType) -> Callable[[Callable], "PythonFn"]: # turn python function into a waveform instruction.""" def waveform_wrapper(fn: Callable) -> "PythonFn": - return PythonFn(fn, duration) + return PythonFn.create(fn, duration) return waveform_wrapper @@ -45,7 +45,7 @@ class Alignment(str, Enum): Right = "right_aligned" -@dataclass +@dataclass(frozen=True) class Waveform: """ Waveform node in the IR. @@ -73,13 +73,6 @@ class Waveform: ``` """ - # def __post_init__(self): - # self._duration = None - - @property - def duration(self): - raise NotImplementedError(f"duration not implemented for {type(self).__name__}") - def __call__(self, clock_s: float, **kwargs) -> float: return float(self.eval_decimal(Decimal(str(clock_s)), **kwargs)) @@ -216,7 +209,7 @@ def print_node(self): raise NotImplementedError -@dataclass +@dataclass(frozen=True) class AlignedWaveform(Waveform): """ @@ -234,11 +227,7 @@ class AlignedWaveform(Waveform): @property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.waveform.duration - return self._duration + return self.waveform.duration def print_node(self): return "AlignedWaveform" @@ -262,7 +251,7 @@ def children(self): return annotated_children -@dataclass() +@dataclass(frozen=True) class Instruction(Waveform): """Instruction node in the IR. @@ -280,12 +269,10 @@ class Instruction(Waveform): ``` """ - @property - def duration(self): - return self._duration + pass -@dataclass(init=False) +@dataclass(init=False, frozen=True) class Linear(Instruction): """ ```bnf @@ -303,12 +290,13 @@ class Linear(Instruction): start: Scalar stop: Scalar - duration: InitVar[Scalar] + duration: Scalar - def __init__(self, start, stop, duration): - self.start = cast(start) - self.stop = cast(stop) - self._duration = cast(duration) + @beartype + def __init__(self, start: ScalarType, stop: ScalarType, duration: ScalarType): + object.__setattr__(self, "start", cast(start)) + object.__setattr__(self, "stop", cast(stop)) + object.__setattr__(self, "duration", cast(duration)) def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: start_value = self.start(**kwargs) @@ -328,7 +316,7 @@ def children(self): return {"start": self.start, "stop": self.stop, "duration": self.duration} -@dataclass(init=False) +@dataclass(init=False, frozen=True) class Constant(Instruction): """ ```bnf @@ -344,11 +332,12 @@ class Constant(Instruction): """ value: Scalar - duration: InitVar[Scalar] + duration: Scalar - def __init__(self, value, duration): - self.value = cast(value) - self._duration = cast(duration) + @beartype + def __init__(self, value: ScalarType, duration: ScalarType): + object.__setattr__(self, "value", cast(value)) + object.__setattr__(self, "duration", cast(duration)) def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: constant_value = self.value(**kwargs) @@ -364,7 +353,7 @@ def children(self): return {"value": self.value, "duration": self.duration} -@dataclass(init=False) +@dataclass(init=False, frozen=True) class Poly(Instruction): """ ```bnf @@ -379,12 +368,13 @@ class Poly(Instruction): """ - coeffs: List[Scalar] - duration: InitVar[Scalar] + coeffs: Tuple[Scalar, ...] + duration: Scalar - def __init__(self, coeffs, duration): - self.coeffs = cast(coeffs) - self._duration = cast(duration) + @beartype + def __init__(self, coeffs: List[ScalarType], duration: ScalarType): + object.__setattr__(self, "coeffs", tuple(map(cast, coeffs))) + object.__setattr__(self, "duration", cast(duration)) def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: # b + x + x^2 + ... + x^n-1 + x^n @@ -417,11 +407,11 @@ def children(self): else: annotated_coeffs["t^" + str(i)] = coeff - annotated_coeffs["duration"] = self._duration + annotated_coeffs["duration"] = self.duration return annotated_coeffs -@dataclass(init=False) +@dataclass(frozen=True) class PythonFn(Instruction): """ @@ -431,15 +421,12 @@ class PythonFn(Instruction): """ fn: Callable # [[float, ...], float] # f(t) -> value + duration: Scalar parameters: List[Union[Variable, AssignedVariable]] # come from ast inspect - duration: InitVar[Scalar] default_param_values: Dict[str, Decimal] - default_arguements: Dict[str, Any] - - def __init__(self, fn: Callable, duration: Any): - self.fn = fn - self._duration = cast(duration) + @staticmethod + def create(fn: Callable, duration: ScalarType) -> "PythonFn": signature = inspect.getfullargspec(fn) if signature.varargs is not None: @@ -450,15 +437,13 @@ def __init__(self, fn: Callable, duration: Any): # get default kwonly first: variables = [] - self.default_param_values = {} - self.default_arguements = {} + default_param_values = {} if signature.kwonlydefaults is not None: for name, value in signature.kwonlydefaults.items(): if isinstance(value, (Real, Decimal)): variables.append(name) - self.default_param_values[name] = Decimal(str(value)) + default_param_values[name] = Decimal(str(value)) else: - # self.default_arguements[name] = value raise ValueError( f"Default value for parameter {name} is not Real or Decimal, " "cannot convert to Variable." @@ -466,7 +451,24 @@ def __init__(self, fn: Callable, duration: Any): variables += signature.args[1:] variables += signature.kwonlyargs - self.parameters = list(map(var, variables)) + + parameters = list(map(var, variables)) + duration = cast(duration) + + return PythonFn(fn, duration, parameters, default_param_values) + + @cached_property + def _hash_value(self) -> int: + return ( + hash(self.__class__) + ^ hash(self.fn) + ^ hash(self.duration) + ^ hash(tuple(self.parameters)) + ^ hash(frozenset(self.default_param_values.items())) + ) + + def __hash__(self) -> int: + return self._hash_value def eval_decimal(self, clock_s: Decimal, **assignments) -> Decimal: new_assignments = {**self.default_param_values, **assignments} @@ -477,7 +479,6 @@ def eval_decimal(self, clock_s: Decimal, **assignments) -> Decimal: kwargs = { param.name: float(param(**new_assignments)) for param in self.parameters } - kwargs = {**self.default_arguements, **kwargs} return Decimal( str( self.fn( @@ -499,7 +500,7 @@ def sample( return Sample(self, interpolation, cast(dt)) -@dataclass +@dataclass(frozen=True) class SmoothingKernel: def __call__(self, value: float) -> float: raise NotImplementedError @@ -515,61 +516,61 @@ class InfiniteSmoothingKernel(SmoothingKernel): pass -@dataclass +@dataclass(frozen=True) class Gaussian(InfiniteSmoothingKernel): def __call__(self, value: float) -> float: return np.exp(-(value**2) / 2) / np.sqrt(2 * np.pi) -@dataclass +@dataclass(frozen=True) class Logistic(InfiniteSmoothingKernel): def __call__(self, value: float) -> float: return np.exp(-(np.logaddexp(0, value) + np.logaddexp(0, -value))) -@dataclass +@dataclass(frozen=True) class Sigmoid(InfiniteSmoothingKernel): def __call__(self, value: float) -> float: return (2 / np.pi) * np.exp(-np.logaddexp(-value, value)) -@dataclass +@dataclass(frozen=True) class Triangle(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return np.maximum(0, 1 - np.abs(value)) -@dataclass +@dataclass(frozen=True) class Uniform(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return np.asarray(np.abs(value) <= 1, dtype=np.float64).squeeze() -@dataclass +@dataclass(frozen=True) class Parabolic(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return (3 / 4) * np.maximum(0, 1 - value**2) -@dataclass +@dataclass(frozen=True) class Biweight(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return (15 / 16) * np.maximum(0, 1 - value**2) ** 2 -@dataclass +@dataclass(frozen=True) class Triweight(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return (35 / 32) * np.maximum(0, 1 - value**2) ** 3 -@dataclass +@dataclass(frozen=True) class Tricube(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return (70 / 81) * np.maximum(0, 1 - np.abs(value) ** 3) ** 3 -@dataclass +@dataclass(frozen=True) class Cosine(FiniteSmoothingKernel): def __call__(self, value: float) -> float: return np.maximum(0, np.pi / 4 * np.cos(np.pi / 2 * value)) @@ -587,7 +588,7 @@ def __call__(self, value: float) -> float: CosineKernel = Cosine() -@dataclass(init=False) +@dataclass(init=False, frozen=True) class Smooth(Waveform): """ ```bnf @@ -624,18 +625,13 @@ def __init__(self, radius, kernel, waveform): else: raise ValueError(f"Invalid kernel: {kernel}") - self.radius = cast(radius) - self.kernel = kernel - self.waveform = waveform - super().__init__() + object.__setattr__(self, "radius", cast(radius)) + object.__setattr__(self, "kernel", kernel) + object.__setattr__(self, "waveform", waveform) @property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.waveform.duration - return self._duration + return self.waveform.duration def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: float_clock_s = float(clock_s) @@ -665,7 +661,7 @@ def integrade(s): raise ValueError(f"Invalid kernel: {self.kernel}") -@dataclass +@dataclass(frozen=True) class Slice(Waveform): """ ``` @@ -676,15 +672,14 @@ class Slice(Waveform): waveform: Waveform interval: Interval - @property + @cached_property def duration(self): - if hasattr(self, "_duration"): - return self._duration + from bloqade.ir.scalar import Slice + + if self.interval.start is None and self.interval.stop is None: + raise ValueError("Interval must have a start or stop value") - start = self.interval.start - stop = self.interval.stop - self._duration = self.waveform.duration[start:stop] - return self._duration + return Slice(self.waveform.duration, self.interval) def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: if clock_s > self.duration(**kwargs): @@ -702,7 +697,7 @@ def children(self): return [self.waveform, self.interval] -@dataclass +@dataclass(frozen=True) class Append(Waveform): """ ```bnf @@ -710,18 +705,15 @@ class Append(Waveform): ``` """ - waveforms: List[Waveform] + waveforms: Tuple[Waveform, ...] - @property + @cached_property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = cast(0.0) + duration = cast(0.0) for waveform in self.waveforms: - self._duration = self._duration + waveform.duration + duration = duration + waveform.duration - return self._duration + return duration def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: append_time = Decimal(0) @@ -742,7 +734,7 @@ def children(self): return self.waveforms -@dataclass +@dataclass(frozen=True) class Negative(Waveform): """ ```bnf @@ -754,12 +746,7 @@ class Negative(Waveform): @property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.waveform.duration - - return self._duration + return self.waveform.duration def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: return -self.waveform.eval_decimal(clock_s, **kwargs) @@ -771,7 +758,7 @@ def children(self): return [self.waveform] -@dataclass(init=False) +@dataclass(init=False, frozen=True) class Scale(Waveform): """ ```bnf @@ -783,17 +770,12 @@ class Scale(Waveform): waveform: Waveform def __init__(self, scalar, waveform: Waveform): - self.scalar = cast(scalar) - self.waveform = waveform + object.__setattr__(self, "scalar", cast(scalar)) + object.__setattr__(self, "waveform", waveform) @property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.waveform.duration - - return self._duration + return self.waveform.duration def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: return self.scalar(**kwargs) * self.waveform.eval_decimal(clock_s, **kwargs) @@ -805,7 +787,7 @@ def children(self): return [self.scalar, self.waveform] -@dataclass +@dataclass(frozen=True) class Add(Waveform): """ ```bnf @@ -816,14 +798,9 @@ class Add(Waveform): left: Waveform right: Waveform - @property + @cached_property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.left.duration.max(self.right.duration) - - return self._duration + return self.left.duration.max(self.right.duration) def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: return self.left(clock_s, **kwargs) + self.right(clock_s, **kwargs) @@ -835,7 +812,7 @@ def children(self): return [self.left, self.right] -@dataclass +@dataclass(frozen=True) class Record(Waveform): """ ```bnf @@ -848,12 +825,7 @@ class Record(Waveform): @property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.waveform.duration - - return self._duration + return self.waveform.duration def eval_decimal(self, clock_s: Decimal, **kwargs) -> Decimal: return self.waveform(clock_s, **kwargs) @@ -870,7 +842,7 @@ class Interpolation(str, Enum): Constant = "constant" -@dataclass +@dataclass(frozen=True) class Sample(Waveform): """ ```bnf @@ -884,12 +856,7 @@ class Sample(Waveform): @property def duration(self): - if hasattr(self, "_duration"): - return self._duration - - self._duration = self.waveform.duration - - return self._duration + return self.waveform.duration def samples(self, **kwargs) -> Tuple[List[Decimal], List[Decimal]]: duration = self.duration(**kwargs) diff --git a/src/bloqade/ir/tree_print.py b/src/bloqade/ir/tree_print.py index e59a7217e..30340192d 100644 --- a/src/bloqade/ir/tree_print.py +++ b/src/bloqade/ir/tree_print.py @@ -161,7 +161,7 @@ def print(self, node, cycle=None): cycle = MAX_TREE_DEPTH # list of children - children = node.children().copy() + children = node.children() node_str = node.print_node() if color_enabled: @@ -184,6 +184,8 @@ def print(self, node, cycle=None): if this_print_annotation: children = list(children.items()) + else: + children = list(children) while not len(children) == 0: child_prefix = self.state.prefix diff --git a/tests/test_field.py b/tests/test_field.py index 94a72c463..2dab3aaba 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -90,9 +90,11 @@ def test_scal_loc(): def test_field_scaled_locations(): Loc = ScaledLocations({1: 1.0, 2: 2.0}) Loc2 = ScaledLocations({3: 1.0, 4: 2.0}) + Loc3 = ScaledLocations({1: 1.0, 2: 2.0, 3: 1.0, 4: 2.0}) f1 = Field({Loc: Linear(start=1.0, stop="x", duration=3.0)}) f2 = Field({Loc: Linear(start=1.0, stop="x", duration=3.0)}) f3 = Field({Loc2: Linear(start=1.0, stop="x", duration=3.0)}) + f4 = Field({Loc3: Linear(start="y", stop="x", duration=3.0)}) # add with non field with pytest.raises(ValueError): @@ -128,8 +130,19 @@ def test_field_scaled_locations(): # add with field diff spat-mod o2 = f1.add(f3) - assert len(o2.drives.keys()) == 2 + assert o2 == Field({Loc3: Linear(start=1.0, stop="x", duration=3.0)}) + # assert len(o2.drives.keys()) == 1 assert f2.print_node() == "Field" # assert type(hash(f1)) == int assert f1.children() == [Drive(k, v) for k, v in f1.drives.items()] + + assert hash(f1) + + o3 = f1.add(f4) + assert o3 == Field( + { + Loc3: Linear(start="y", stop="x", duration=3.0), + Loc: Linear(start=1.0, stop="x", duration=3.0), + } + ) diff --git a/tests/test_hardware_codegen.py b/tests/test_hardware_codegen.py index a45084a2e..1525beef1 100644 --- a/tests/test_hardware_codegen.py +++ b/tests/test_hardware_codegen.py @@ -180,7 +180,7 @@ def my_cos(time): assert my_cos(1) == np.cos(1) - wv = PythonFn(my_cos, duration=1.0) + wv = PythonFn.create(my_cos, duration=1.0) dt = cast(0.1) wf = Sample(wv, Interpolation.Constant, dt) @@ -584,7 +584,7 @@ def my_cos(time): assert my_cos(1) == np.cos(1) - wv = PythonFn(my_cos, duration=1.0) + wv = PythonFn.create(my_cos, duration=1.0) dt = cast(0.1) wf = Sample(wv, Interpolation.Linear, dt) diff --git a/tests/test_waveform.py b/tests/test_waveform.py index 746b95244..7507ebe46 100644 --- a/tests/test_waveform.py +++ b/tests/test_waveform.py @@ -44,6 +44,7 @@ def test_wvfm_linear(): assert wf.print_node() == "Linear" assert wf.eval_decimal(clock_s=Decimal("6.0")) == 0 + assert isinstance(hash(wf), int) assert wf.children() == { "start": cast(1.0), @@ -58,6 +59,7 @@ def test_wvfm_constant(): assert wf.print_node() == "Constant" assert wf.eval_decimal(clock_s=Decimal("6.0")) == 0 assert wf.children() == {"value": cast(1.0), "duration": cast(3.0)} + assert isinstance(hash(wf), int) mystdout = StringIO() p = PP(mystdout) @@ -100,12 +102,12 @@ def my_func3(time, omega, **phi): assert my_func3(3, 2) == 3 with pytest.raises(ValueError): - PythonFn(my_func2, duration=1.0) + PythonFn.create(my_func2, duration=1.0) with pytest.raises(ValueError): - PythonFn(my_func3, duration=1.0) + PythonFn.create(my_func3, duration=1.0) - wf = PythonFn(my_func, duration=1.0) + wf = PythonFn.create(my_func, duration=1.0) awf = annot_my_func assert wf.eval_decimal(Decimal("0.56"), omega=1, amplitude=4) == awf.eval_decimal( @@ -122,6 +124,7 @@ def my_func3(time, omega, **phi): "amplitude": cast("amplitude"), } assert wf.duration == cast(1.0) + assert isinstance(hash(wf), int) mystdout = StringIO() p = PP(mystdout) @@ -148,8 +151,9 @@ def test_wvfm_app(): wf3 = Append([wf, wf2]) assert wf3.print_node() == "Append" - assert wf3.children() == [wf, wf2] + assert wf3.children() == (wf, wf2) assert wf3.eval_decimal(Decimal(10)) == Decimal(0) + assert isinstance(hash(wf), int) mystdout = StringIO() p = PP(mystdout) @@ -168,6 +172,7 @@ def test_wvfm_neg(): assert wf2.print_node() == "Negative" assert wf2.children() == [wf] + assert isinstance(hash(wf), int) assert wf2.eval_decimal(Decimal("0.5")) == Decimal("-1.0") @@ -193,6 +198,7 @@ def test_wvfm_scale(): assert wf2.print_node() == "Scale" assert wf2.children() == [cast(2.0), wf] + assert isinstance(hash(wf), int) assert wf2.eval_decimal(Decimal("0.5")) == Decimal("2.0") @@ -227,6 +233,7 @@ def test_wvfn_add(): assert wf3.print_node() == "+" assert wf3.children() == [wf, wf2] + assert isinstance(hash(wf), int) assert wf3.eval_decimal(Decimal("0")) == Decimal("2.0") assert wf3.eval_decimal(Decimal("2.5")) == Decimal("1.0") @@ -261,6 +268,7 @@ def test_wvfn_rec(): assert re.print_node() == "Record" assert re.children() == {"Waveform": wf, "Variable": cast("tst")} + assert isinstance(hash(wf), int) assert re.eval_decimal(Decimal("0")) == Decimal("1.0") assert re.duration == cast(3.0) @@ -298,6 +306,7 @@ def test_wvfn_poly(): } assert wf.eval_decimal(Decimal("0.5")) == (1) + (2) * 0.5 + (3) * 0.5**2 assert wf.eval_decimal(Decimal("20")) == Decimal("0") + assert isinstance(hash(wf), int) ##----------------------------- @@ -379,6 +388,7 @@ def test_wvfn_slice(): assert wf.eval_decimal(Decimal("0.4")) == 0 assert wf.eval_decimal(Decimal("0.2")) == 2.0 assert wf.children() == [wv, iv] + assert isinstance(hash(wf), int) mystdout = StringIO() p = PP(mystdout) @@ -416,6 +426,7 @@ def test_wvfm_align(): wf = AlignedWaveform(wv, Alignment.Left, cast(0.2)) assert wf.print_node() == "AlignedWaveform" assert wf.children() == {"Waveform": wv, "Alignment": "Left", "Value": cast(0.2)} + assert isinstance(hash(wf), int) wf2 = AlignedWaveform(wv, Alignment.Left, AlignedValue.Right) assert wf2.print_node() == "AlignedWaveform" @@ -451,7 +462,7 @@ def my_cos(time): assert my_cos(1) == np.cos(1) - wv = PythonFn(my_cos, duration=1.0) + wv = PythonFn.create(my_cos, duration=1.0) dt = cast(0.1) wf = Sample(wv, Interpolation.Constant, dt) @@ -460,6 +471,7 @@ def my_cos(time): assert wf.children() == {"Waveform": wv, "sample_step": dt} assert wf.eval_decimal(Decimal(0.05)) == my_cos(0) assert float(wf.eval_decimal(Decimal(0))) == my_cos(0) + assert isinstance(hash(wf), int) wf2 = Sample(wv, Interpolation.Linear, dt) @@ -469,6 +481,7 @@ def my_cos(time): assert float(wf2.eval_decimal(Decimal(0.05))) == float(my_cos(0) + slope * 0.05) assert float(wf2.eval_decimal(Decimal(3))) == 0 assert float(wf2.eval_decimal(Decimal(0))) == my_cos(0) + assert isinstance(hash(wf), int) mystdout = StringIO() p = PP(mystdout) diff --git a/tests/test_waveform_visitor.py b/tests/test_waveform_visitor.py index a421e1054..8aafda6b1 100644 --- a/tests/test_waveform_visitor.py +++ b/tests/test_waveform_visitor.py @@ -32,7 +32,7 @@ wv_scale = Scale(cast(4), wv_linear) wv_add = Add(wv_linear, wv_constant) wv_record = Record(wv_linear, cast("tst")) -wv_python = PythonFn(lambda x: x**2, duration=0.5) +wv_python = PythonFn.create(lambda x: x**2, duration=0.5) wv_sample = Sample(wv_linear, Interpolation("linear"), dt=cast(0.5)) wv_poly = Poly(coeffs=[0.1, 0.2], duration=0.5) wv_smooth = Smooth(radius=cast(0.5), kernel=GaussianKernel, waveform=wv_linear)