From 1909448515434bb01af7794614e45836e276f36d Mon Sep 17 00:00:00 2001 From: JCGoran Date: Mon, 7 Oct 2024 09:37:37 +0200 Subject: [PATCH] Add support for diffing expressions with indexed vars in `differentiate2c` (#1483) --- python/nmodl/ode.py | 7 ++++++- test/unit/ode/test_ode.py | 14 ++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index e40cb47c6..cd6b2b27a 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -247,6 +247,11 @@ def _interweave_eqs(F, J): return code +def make_symbol(var, /): + """Create SymPy symbol from a variable.""" + return sp.Symbol(var, real=True) if isinstance(var, str) else var + + def solve_lin_system( eq_strings, vars, @@ -618,7 +623,7 @@ def differentiate2c( vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = {str(var): make_symbol(var) for var in vars} sympy_vars[dependent_var] = x # parse string into SymPy equation diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index df5f6c4f0..6eae70699 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from nmodl.ode import differentiate2c, integrate2c +from nmodl.ode import differentiate2c, integrate2c, make_symbol import pytest import sympy as sp @@ -29,7 +29,7 @@ def _equivalent( """ lhs = lhs.replace("pow(", "Pow(") rhs = rhs.replace("pow(", "Pow(") - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = {str(var): make_symbol(var) for var in vars} for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)): eq_l = sp.sympify(l, locals=sympy_vars) eq_r = sp.sympify(r, locals=sympy_vars) @@ -101,6 +101,16 @@ def test_differentiate2c(): "g", ) + assert _equivalent( + differentiate2c( + "(s[0] + s[1])*(z[0]*z[1]*z[2])*x", + "x", + {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, + ), + "(s[0] + s[1])*(z[0]*z[1]*z[2])", + {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, + ) + result = differentiate2c( "-f(x)", "x",