Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: support unevaluated_expression default arguments #376

Merged
merged 4 commits into from
Dec 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
@@ -182,6 +182,7 @@
"sharey",
"startswith",
"suptitle",
"sympified",
"sympify",
"symplot",
"theano",
50 changes: 50 additions & 0 deletions docs/usage/sympy.ipynb
Original file line number Diff line number Diff line change
@@ -162,6 +162,56 @@
"Math(aslatex({e: e.evaluate() for e in [rho_expr, q_expr]}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Class variables and default arguments to instance arguments are also supported:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"from typing import Any, ClassVar\n",
"\n",
"\n",
"@unevaluated_expression\n",
"class FunkyPower(sp.Expr):\n",
" x: Any\n",
" m: int = 1\n",
" default_return: ClassVar[sp.Expr | None] = None\n",
" _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n",
"\n",
" def evaluate(self) -> sp.Expr | None:\n",
" if self.default_return is None:\n",
" return self.x**self.m\n",
" return self.default_return\n",
"\n",
"\n",
"x = sp.Symbol(\"x\")\n",
"exprs = (\n",
" FunkyPower(x),\n",
" FunkyPower(x, 2),\n",
" FunkyPower(x, m=3),\n",
")\n",
"Math(aslatex({e: e.doit() for e in exprs}))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"FunkyPower.default_return = sp.Rational(0.5)\n",
"Math(aslatex({e: e.doit() for e in exprs}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
14 changes: 9 additions & 5 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
@@ -133,10 +133,10 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]:
@functools.wraps(cls.__new__)
@_insert_args_in_signature(attr_names, idx=1)
def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
attr_values, kwargs = _get_attribute_values(attr_names, *args, **kwargs)
attr_values = sp.sympify(attr_values)
expr = sp.Expr.__new__(cls, *attr_values, **kwargs)
for name, value in zip(attr_names, attr_values):
positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs)
sympified_args = sp.sympify(positional_args)
expr = sp.Expr.__new__(cls, *sympified_args, **hints)
for name, value in zip(attr_names, sympified_args):
setattr(expr, name, value)
if evaluate:
return expr.evaluate()
@@ -147,7 +147,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:


def _get_attribute_values(
attr_names: tuple[str, ...], *args, **kwargs
cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs
) -> tuple[tuple, dict[str, Any]]:
if len(args) == len(attr_names):
return args, kwargs
@@ -163,6 +163,10 @@ def _get_attribute_values(
if name in kwargs:
attr_values.append(kwargs.pop(name))
remaining_attr_names.pop(0)
elif hasattr(cls, name):
default_value = getattr(cls, name)
attr_values.append(default_value)
remaining_attr_names.pop(0)
if remaining_attr_names:
msg = f"Missing constructor arguments: {', '.join(remaining_attr_names)}"
raise ValueError(msg)
54 changes: 54 additions & 0 deletions tests/sympy/test_decorator.py
Original file line number Diff line number Diff line change
@@ -134,6 +134,43 @@ def evaluate(self) -> sp.Expr:
assert y_expr.doit() == 5**3


def test_unevaluated_expression_default_argument():
@unevaluated_expression
class FunkyPower(sp.Expr):
x: Any
m: int = 1
default_return: ClassVar[float | None] = None

def evaluate(self) -> sp.Expr:
if self.default_return is None:
return self.x**self.m
return sp.sympify(self.default_return)

x = sp.Symbol("x")
exprs = (
FunkyPower(x),
FunkyPower(x, 2),
FunkyPower(x, m=3),
)
assert exprs[0].doit() == x
assert exprs[1].doit() == x**2
assert exprs[2].doit() == x**3
for expr in exprs:
assert expr.x is x
assert isinstance(expr.m, sp.Integer)
assert expr.default_return is None

half = sp.Rational(1, 2)
FunkyPower.default_return = half
assert exprs[0].doit() == half
assert exprs[1].doit() == half
assert exprs[2].doit() == half
for expr in exprs:
assert expr.x is x
assert isinstance(expr.m, sp.Integer)
assert expr.default_return is half


def test_unevaluated_expression_callable():
@unevaluated_expression(implement_doit=False)
class Squared(sp.Expr):
@@ -153,3 +190,20 @@ class MySqrt(sp.Expr):
expr = MySqrt(-1)
assert expr.is_commutative
assert expr.is_complex # type: ignore[attr-defined]


def test_unevaluated_expression_default_args():
@unevaluated_expression
class MyExpr(sp.Expr):
x: Any
m: int = 2

def evaluate(self) -> sp.Expr:
return self.x**self.m

expr1 = MyExpr(x=5)
assert str(expr1) == "MyExpr(5, 2)"
assert expr1.doit() == 5**2

expr2 = MyExpr(4, 3)
assert expr2.doit() == 4**3