Skip to content

Commit 3279fbc

Browse files
authored
Merge pull request #131 from SymbolicML/fix-zygote-mutation
fix: move non_differentiable to special module for JET masking
2 parents 24371c2 + 0e66242 commit 3279fbc

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/EvaluationHelpers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import ..NodeModule: AbstractExpressionNode
88
import ..EvaluateModule: eval_tree_array
99
import ..EvaluateDerivativeModule: eval_grad_tree_array
1010

11+
# Needs to be special function so we can declare it non-differentiable to Zygote
1112
function _set_nan!(out)
1213
out .= convert(eltype(out), NaN)
1314
return nothing
1415
end
15-
@non_differentiable _set_nan!(out)
1616

1717
# Evaluation:
1818
"""

src/NonDifferentiableDeclarations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import ..NodeModule: AbstractExpressionNode, AbstractNode
66
import ..NodeUtilsModule: tree_mapreduce
77
import ..ExpressionModule:
88
AbstractExpression, get_operators, get_variable_names, _validate_input
9+
import ..EvaluationHelpersModule: _set_nan!
910

1011
#! format: off
1112
@non_differentiable tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type)
1213
@non_differentiable tree_mapreduce(f::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type)
1314
@non_differentiable get_operators(ex::Union{AbstractExpression,AbstractExpressionNode}, operators::Union{AbstractOperatorEnum,Nothing})
1415
@non_differentiable get_variable_names(ex::AbstractExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing})
1516
@non_differentiable _validate_input(ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing})
17+
@non_differentiable _set_nan!(::Any)
1618
#! format: on
1719

1820
end

0 commit comments

Comments
 (0)