Skip to content

Commit 24371c2

Browse files
authored
Merge pull request #130 from SymbolicML/fix-zygote-mutation
fix: mutation error for Zygote
2 parents 4c9508f + cd69d8d commit 24371c2

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/EvaluationHelpers.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
module EvaluationHelpersModule
22

3+
using ChainRulesCore: @non_differentiable
4+
35
import Base: adjoint
46
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
57
import ..NodeModule: AbstractExpressionNode
68
import ..EvaluateModule: eval_tree_array
79
import ..EvaluateDerivativeModule: eval_grad_tree_array
810

11+
function _set_nan!(out)
12+
out .= convert(eltype(out), NaN)
13+
return nothing
14+
end
15+
@non_differentiable _set_nan!(out)
16+
917
# Evaluation:
1018
"""
1119
(tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...)
@@ -27,7 +35,7 @@ and triplets of operations for lower memory usage.
2735
"""
2836
function (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...)
2937
out, did_finish = eval_tree_array(tree, X, operators; kws...)
30-
!did_finish && (out .= convert(eltype(out), NaN))
38+
!did_finish && _set_nan!(out)
3139
return out
3240
end
3341
"""
@@ -56,7 +64,7 @@ function _grad_evaluator(
5664
tree::AbstractExpressionNode, X, operators::OperatorEnum; variable=Val(true), kws...
5765
)
5866
_, grad, did_complete = eval_grad_tree_array(tree, X, operators; variable, kws...)
59-
!did_complete && (grad .= convert(eltype(grad), NaN))
67+
!did_complete && _set_nan!(grad)
6068
return grad
6169
end
6270
function _grad_evaluator(

0 commit comments

Comments
 (0)