1
1
module EvaluationHelpersModule
2
2
3
+ using ChainRulesCore: @non_differentiable
4
+
3
5
import Base: adjoint
4
6
import .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
5
7
import .. NodeModule: AbstractExpressionNode
6
8
import .. EvaluateModule: eval_tree_array
7
9
import .. EvaluateDerivativeModule: eval_grad_tree_array
8
10
11
+ function _set_nan! (out)
12
+ out .= convert (eltype (out), NaN )
13
+ return nothing
14
+ end
15
+ @non_differentiable _set_nan! (out)
16
+
9
17
# Evaluation:
10
18
"""
11
19
(tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...)
@@ -27,7 +35,7 @@ and triplets of operations for lower memory usage.
27
35
"""
28
36
function (tree:: AbstractExpressionNode )(X, operators:: OperatorEnum ; kws... )
29
37
out, did_finish = eval_tree_array (tree, X, operators; kws... )
30
- ! did_finish && (out . = convert ( eltype (out), NaN ) )
38
+ ! did_finish && _set_nan! (out)
31
39
return out
32
40
end
33
41
"""
@@ -56,7 +64,7 @@ function _grad_evaluator(
56
64
tree:: AbstractExpressionNode , X, operators:: OperatorEnum ; variable= Val (true ), kws...
57
65
)
58
66
_, grad, did_complete = eval_grad_tree_array (tree, X, operators; variable, kws... )
59
- ! did_complete && (grad . = convert ( eltype (grad), NaN ) )
67
+ ! did_complete && _set_nan! (grad)
60
68
return grad
61
69
end
62
70
function _grad_evaluator (
0 commit comments