Skip to content

Commit

Permalink
Get rid of constant_function in AutoEnzyme (#401)
Browse files Browse the repository at this point in the history
* Fully remove `constant_function` for `AutoEnzyme` following ADTypes yanking

* Typo
  • Loading branch information
gdalle authored Aug 2, 2024
1 parent 0c625e4 commit d0e26ef
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 26 deletions.
4 changes: 2 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.5.10"
version = "0.5.11"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -44,7 +44,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
ADTypes = "1.6.1"
ADTypes = "1.6.2"
ChainRulesCore = "1.23.0"
Compat = "3.46,4.2"
Diffractor = "=0.2.6"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
struct AutoDeferredEnzyme{M,constant_function} <: ADTypes.AbstractADType
struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType
mode::M
end

ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode))

function DI.nested(backend::AutoEnzyme{M,constant_function}) where {M,constant_function}
return AutoDeferredEnzyme{M,constant_function}(backend.mode)
function DI.nested(backend::AutoEnzyme{M}) where {M}
return AutoDeferredEnzyme{M}(backend.mode)
end

const AnyAutoEnzyme{M,constant_function} = Union{
AutoEnzyme{M,constant_function},AutoDeferredEnzyme{M,constant_function}
}
const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}}

# forward mode if possible
forward_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode
Expand All @@ -33,20 +31,3 @@ function DI.basis(::AnyAutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where
end

get_f_and_df(f, ::AnyAutoEnzyme) = Const(f)

#=
# commented out until Enzyme errors when non-duplicated data is written to
function get_f_and_df(f, backend::AnyAutoEnzyme{M,false}) where {M}
mode = isnothing(backend.mode) ? Reverse : backend.mode
A = guess_activity(typeof(f), mode)
if A <: Const || A <: Active
return Const(f)
elseif A <: Duplicated || A <: DuplicatedNoNeed || A <: MixedDuplicated
df = make_zero(f)
return Duplicated(f, df)
else
error("Unexpected activity guessed for the function `f`.")
end
end
=#
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDiff", "Lux", "L
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"

[compat]
ADTypes = "1.0.0"
ADTypes = "1.6.2"
Chairmarks = "1.2.1"
Compat = "3.46,4.2"
ComponentArrays = "0.15"
Expand Down

0 comments on commit d0e26ef

Please sign in to comment.