Skip to content

Commit

Permalink
Merge pull request #1041 from vyudu/DSL
Browse files Browse the repository at this point in the history
expanding equations when passed to `@reaction_network`
  • Loading branch information
vyudu authored Sep 9, 2024
2 parents 31b0f99 + 672225a commit 738a7d8
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Catalyst"
uuid = "479239e8-5488-4da2-87a7-35f2df7eef83"
version = "14.4"
version = "14.4.0"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand Down
31 changes: 21 additions & 10 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))

# Get macro options.
if length(unique(arg.args[1] for arg in option_lines)) < length(option_lines)
error("Some options where given multiple times.")
error("Some options were given multiple times.")
end
options = Dict(map(arg -> Symbol(String(arg.args[1])[2:end]) => arg,
option_lines))
Expand All @@ -315,12 +315,12 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
parameters_declared = extract_syms(options, :parameters)
variables_declared = extract_syms(options, :variables)

# Reads more options.
# Reads equations.
vars_extracted, add_default_diff, equations = read_equations_options(
options, variables_declared)
variables = vcat(variables_declared, vars_extracted)

# handle independent variables
# Handle independent variables
if haskey(options, :ivs)
ivs = Tuple(extract_syms(options, :ivs))
ivexpr = copy(options[:ivs])
Expand All @@ -339,14 +339,16 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
combinatoric_ratelaws = true
end

# Reads more options.
# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs)

# Collect species and parameters, including ones inferred from the reactions.
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
variables)))
species_extracted, parameters_extracted = extract_species_and_parameters!(reactions,
declared_syms)
species_extracted, parameters_extracted = extract_species_and_parameters!(
reactions, declared_syms)

species = vcat(species_declared, species_extracted)
parameters = vcat(parameters_declared, parameters_extracted)

Expand Down Expand Up @@ -376,9 +378,11 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
push!(rxexprs.args, get_rxexprs(reaction))
end
for equation in equations
equation = escape_equation_RHS!(equation)
push!(rxexprs.args, equation)
end

# Output code corresponding to the reaction system.
quote
$ivexpr
$ps
Expand Down Expand Up @@ -572,7 +576,7 @@ function get_rxexprs(rxstruct)
subs_stoich_init = deepcopy(subs_init)
prod_init = isempty(rxstruct.products) ? nothing : :([])
prod_stoich_init = deepcopy(prod_init)
reaction_func = :(Reaction($(recursive_expand_functions!(rxstruct.rate)), $subs_init,
reaction_func = :(Reaction($(recursive_escape_functions!(rxstruct.rate)), $subs_init,
$prod_init, $subs_stoich_init, $prod_stoich_init,
metadata = $(rxstruct.metadata)))
for sub in rxstruct.substrates
Expand Down Expand Up @@ -904,17 +908,24 @@ end

### Generic Expression Manipulation ###

# Recursively traverses an expression and replaces special function call like "hill(...)" with the actual corresponding expression.
function recursive_expand_functions!(expr::ExprValues)
# Recursively traverses an expression and escapes all the user-defined functions. Special function calls like "hill(...)" are not expanded.
function recursive_escape_functions!(expr::ExprValues)
(typeof(expr) != Expr) && (return expr)
foreach(i -> expr.args[i] = recursive_expand_functions!(expr.args[i]),
foreach(i -> expr.args[i] = recursive_escape_functions!(expr.args[i]),
1:length(expr.args))
if expr.head == :call
!isdefined(Catalyst, expr.args[1]) && (expr.args[1] = esc(expr.args[1]))
end
expr
end

# Recursively escape functions in the right-hand-side of an equation written using user-defined functions. Special function calls like "hill(...)" are not expanded.
function escape_equation_RHS!(eqexpr::Expr)
rhs = recursive_escape_functions!(eqexpr.args[3])
eqexpr.args[3] = rhs
eqexpr
end

# Returns the length of a expression tuple, or 1 if it is not an expression tuple (probably a Symbol/Numerical).
function tup_leng(ex::ExprValues)
(typeof(ex) == Expr && ex.head == :tuple) && (return length(ex.args))
Expand Down
74 changes: 74 additions & 0 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ seed = rand(rng, 1:100)

# Sets the default `t` to use.
t = default_t()
D = default_time_deriv()

### Tests `@parameters`, `@species`, and `@variables` Options ###

Expand Down Expand Up @@ -952,3 +953,76 @@ let
@unpack k1, A = rn3
@test isequal(rl, k1*A^2)
end

# Test whether user-defined functions are properly expanded in equations.
let
f(A, t) = 2*A*t

# Test user-defined function
rn = @reaction_network begin
@equations D(A) ~ f(A, t)
end
@test length(equations(rn)) == 1
@test equations(rn)[1] isa Equation
@species A(t)
@test isequal(equations(rn)[1], D(A) ~ 2*A*t)


# Test whether expansion happens properly for unregistered/registered functions.
hill_unregistered(A, v, K, n) = v*(A^n) / (A^n + K^n)
rn2 = @reaction_network begin
@parameters v K n
@equations D(A) ~ hill_unregistered(A, v, K, n)
end
@test length(equations(rn2)) == 1
@test equations(rn2)[1] isa Equation
@parameters v K n
@test isequal(equations(rn2)[1], D(A) ~ v*(A^n) / (A^n + K^n))

hill2(A, v, K, n) = v*(A^n) / (A^n + K^n)
@register_symbolic hill2(A, v, K, n)
# Registered symbolic function should not expand.
rn2r = @reaction_network begin
@parameters v K n
@equations D(A) ~ hill2(A, v, K, n)
end
@test length(equations(rn2r)) == 1
@test equations(rn2r)[1] isa Equation
@parameters v K n
@test isequal(equations(rn2r)[1], D(A) ~ hill2(A, v, K, n))


rn3 = @reaction_network begin
@species Iapp(t)
@equations begin
D(A) ~ Iapp
Iapp ~ f(A,t)
end
end
@test length(equations(rn3)) == 2
@test equations(rn3)[1] isa Equation
@test equations(rn3)[2] isa Equation
@variables Iapp(t)
@test isequal(equations(rn3)[1], D(A) ~ Iapp)
@test isequal(equations(rn3)[2], Iapp ~ 2*A*t)

# Test whether the DSL and symbolic ways of creating the network generate the same system
@species Iapp(t) A(t)
eq = [D(A) ~ Iapp, Iapp ~ f(A, t)]
@named rn3_sym = ReactionSystem(eq, t)
rn3_sym = complete(rn3_sym)
@test isequivalent(rn3, rn3_sym)


# Test more complicated expression involving both registered function and a user-defined function.
g(A, K, n) = A^n + K^n
rn4 = @reaction_network begin
@parameters v K n
@equations D(A) ~ hill(A, v, K, n)*g(A, K, n)
end
@test length(equations(rn4)) == 1
@test equations(rn4)[1] isa Equation
@parameters v n
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
end

2 changes: 1 addition & 1 deletion test/reactionsystem_core/coupled_equation_crn_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1042,4 +1042,4 @@ let
u0 = [S1 => 1.0, S2 => 2.0, V1 => 0.1]
ps = [p1 => 2.0, p2 => 3.0]
@test_throws Exception ODEProblem(rs, u0, (0.0, 1.0), ps; structural_simplify = true)
end
end

0 comments on commit 738a7d8

Please sign in to comment.