Skip to content

Commit

Permalink
Fix deferred any active return (#1909)
Browse files Browse the repository at this point in the history
* Fix deferred any active return

* fix

* fix
  • Loading branch information
wsmoses authored Sep 28, 2024
1 parent 7c4a31a commit a5c6fee
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 107 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.3"
EnzymeCore = "0.8.4"
Enzyme_jll = "0.0.151"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27"
LLVM = "6.1, 7, 8, 9"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.8.3"
version = "0.8.4"

[compat]
Adapt = "3, 4"
Expand Down
24 changes: 12 additions & 12 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,21 +289,21 @@ Reverse mode differentiation.
- `Width`: Batch Size (0 if to be automatically derived)
- `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived).
"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end
const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}()
const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}()
@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}()
@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}()
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end
const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false, false, false}()
const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false, false, false}()
@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()

@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}()
@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}()
@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, true, ShadowInit}()
@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, Holomorphic, false, ShadowInit}()

@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()

@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()


"""
Expand Down
Loading

0 comments on commit a5c6fee

Please sign in to comment.