Skip to content

Commit 2aad117

Browse files
authored
Disable thunks for 2nd order AD (#683)
* Allow disabling thunking * Bump to 1.25.1 * Simplify * Unconditionally disable thunks for 2nd order AD * Add comment * Fixes
1 parent 2f2c941 commit 2aad117

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.25.0"
3+
version = "1.25.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/thunks.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# Disable thunks for 2nd order AD.
2+
_usethunks() = true
3+
rrule(::typeof(_usethunks)) = false, Returns((NoTangent(),))
4+
15
abstract type AbstractThunk <: AbstractTangent end
26

37
struct MutateThunkException <: Exception end
@@ -141,7 +145,11 @@ macro thunk(body)
141145
# Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
142146
# so we get useful stack traces if it errors.
143147
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
144-
return :(Thunk($(esc(func))))
148+
return quote
149+
_usethunks() ?
150+
Thunk($(esc(func))) :
151+
$(esc(body))
152+
end
145153
end
146154

147155
"""
@@ -233,6 +241,12 @@ and destroy its inplacability.
233241
struct InplaceableThunk{T<:Thunk,F} <: AbstractThunk
234242
add!::F
235243
val::T
244+
245+
function InplaceableThunk(add!::F, val::T) where {F, T}
246+
_usethunks() ?
247+
new{T, F}(add!, val) :
248+
val
249+
end
236250
end
237251

238252
unthunk(x::InplaceableThunk) = unthunk(x.val)

0 commit comments

Comments
 (0)