Skip to content

Commit

Permalink
Add support for diffing expressions with indexed vars in `differentia…
Browse files Browse the repository at this point in the history
…te2c` (#1483)
  • Loading branch information
JCGoran authored Oct 7, 2024
1 parent 1a97ed1 commit 1909448
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
7 changes: 6 additions & 1 deletion python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def _interweave_eqs(F, J):
return code


def make_symbol(var, /):
"""Create SymPy symbol from a variable."""
return sp.Symbol(var, real=True) if isinstance(var, str) else var


def solve_lin_system(
eq_strings,
vars,
Expand Down Expand Up @@ -618,7 +623,7 @@ def differentiate2c(
vars = set(vars)
vars.discard(dependent_var)
# declare all other supplied variables
sympy_vars = {var: sp.symbols(var, real=True) for var in vars}
sympy_vars = {str(var): make_symbol(var) for var in vars}
sympy_vars[dependent_var] = x

# parse string into SymPy equation
Expand Down
14 changes: 12 additions & 2 deletions test/unit/ode/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from nmodl.ode import differentiate2c, integrate2c
from nmodl.ode import differentiate2c, integrate2c, make_symbol
import pytest

import sympy as sp
Expand All @@ -29,7 +29,7 @@ def _equivalent(
"""
lhs = lhs.replace("pow(", "Pow(")
rhs = rhs.replace("pow(", "Pow(")
sympy_vars = {var: sp.symbols(var, real=True) for var in vars}
sympy_vars = {str(var): make_symbol(var) for var in vars}
for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)):
eq_l = sp.sympify(l, locals=sympy_vars)
eq_r = sp.sympify(r, locals=sympy_vars)
Expand Down Expand Up @@ -101,6 +101,16 @@ def test_differentiate2c():
"g",
)

assert _equivalent(
differentiate2c(
"(s[0] + s[1])*(z[0]*z[1]*z[2])*x",
"x",
{sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])},
),
"(s[0] + s[1])*(z[0]*z[1]*z[2])",
{sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])},
)

result = differentiate2c(
"-f(x)",
"x",
Expand Down

0 comments on commit 1909448

Please sign in to comment.