Skip to content

Commit

Permalink
ENH: support unevaluated_expression default arguments (#376)
Browse files Browse the repository at this point in the history
* DOC: show ClassVar and default argument use
* ENH: improve local variable names
* ENH: support default arguments for `unevaluated_expression`
* MAINT: write test for `unevaluated_expression` default arguments
  • Loading branch information
redeboer authored Dec 18, 2023
1 parent ef0c5ad commit 5182dbf
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 5 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
"sharey",
"startswith",
"suptitle",
"sympified",
"sympify",
"symplot",
"theano",
Expand Down
50 changes: 50 additions & 0 deletions docs/usage/sympy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,56 @@
"Math(aslatex({e: e.evaluate() for e in [rho_expr, q_expr]}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Class variables and default arguments to instance arguments are also supported:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"from typing import Any, ClassVar\n",
"\n",
"\n",
"@unevaluated_expression\n",
"class FunkyPower(sp.Expr):\n",
" x: Any\n",
" m: int = 1\n",
" default_return: ClassVar[sp.Expr | None] = None\n",
" _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n",
"\n",
" def evaluate(self) -> sp.Expr | None:\n",
" if self.default_return is None:\n",
" return self.x**self.m\n",
" return self.default_return\n",
"\n",
"\n",
"x = sp.Symbol(\"x\")\n",
"exprs = (\n",
" FunkyPower(x),\n",
" FunkyPower(x, 2),\n",
" FunkyPower(x, m=3),\n",
")\n",
"Math(aslatex({e: e.doit() for e in exprs}))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"FunkyPower.default_return = sp.Rational(0.5)\n",
"Math(aslatex({e: e.doit() for e in exprs}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
14 changes: 9 additions & 5 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]:
@functools.wraps(cls.__new__)
@_insert_args_in_signature(attr_names, idx=1)
def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
attr_values, kwargs = _get_attribute_values(attr_names, *args, **kwargs)
attr_values = sp.sympify(attr_values)
expr = sp.Expr.__new__(cls, *attr_values, **kwargs)
for name, value in zip(attr_names, attr_values):
positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs)
sympified_args = sp.sympify(positional_args)
expr = sp.Expr.__new__(cls, *sympified_args, **hints)
for name, value in zip(attr_names, sympified_args):
setattr(expr, name, value)
if evaluate:
return expr.evaluate()
Expand All @@ -147,7 +147,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:


def _get_attribute_values(
attr_names: tuple[str, ...], *args, **kwargs
cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs
) -> tuple[tuple, dict[str, Any]]:
if len(args) == len(attr_names):
return args, kwargs
Expand All @@ -163,6 +163,10 @@ def _get_attribute_values(
if name in kwargs:
attr_values.append(kwargs.pop(name))
remaining_attr_names.pop(0)
elif hasattr(cls, name):
default_value = getattr(cls, name)
attr_values.append(default_value)
remaining_attr_names.pop(0)
if remaining_attr_names:
msg = f"Missing constructor arguments: {', '.join(remaining_attr_names)}"
raise ValueError(msg)
Expand Down
54 changes: 54 additions & 0 deletions tests/sympy/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ def evaluate(self) -> sp.Expr:
assert y_expr.doit() == 5**3


def test_unevaluated_expression_default_argument():
@unevaluated_expression
class FunkyPower(sp.Expr):
x: Any
m: int = 1
default_return: ClassVar[float | None] = None

def evaluate(self) -> sp.Expr:
if self.default_return is None:
return self.x**self.m
return sp.sympify(self.default_return)

x = sp.Symbol("x")
exprs = (
FunkyPower(x),
FunkyPower(x, 2),
FunkyPower(x, m=3),
)
assert exprs[0].doit() == x
assert exprs[1].doit() == x**2
assert exprs[2].doit() == x**3
for expr in exprs:
assert expr.x is x
assert isinstance(expr.m, sp.Integer)
assert expr.default_return is None

half = sp.Rational(1, 2)
FunkyPower.default_return = half
assert exprs[0].doit() == half
assert exprs[1].doit() == half
assert exprs[2].doit() == half
for expr in exprs:
assert expr.x is x
assert isinstance(expr.m, sp.Integer)
assert expr.default_return is half


def test_unevaluated_expression_callable():
@unevaluated_expression(implement_doit=False)
class Squared(sp.Expr):
Expand All @@ -153,3 +190,20 @@ class MySqrt(sp.Expr):
expr = MySqrt(-1)
assert expr.is_commutative
assert expr.is_complex # type: ignore[attr-defined]


def test_unevaluated_expression_default_args():
@unevaluated_expression
class MyExpr(sp.Expr):
x: Any
m: int = 2

def evaluate(self) -> sp.Expr:
return self.x**self.m

expr1 = MyExpr(x=5)
assert str(expr1) == "MyExpr(5, 2)"
assert expr1.doit() == 5**2

expr2 = MyExpr(4, 3)
assert expr2.doit() == 4**3

0 comments on commit 5182dbf

Please sign in to comment.