Skip to content

Commit

Permalink
Merge pull request #110 from SciML/DIv6
Browse files Browse the repository at this point in the history
DifferentiationInterface v6: Constant and arguments orders changes
  • Loading branch information
Vaibhavdixit02 authored Oct 1, 2024
2 parents e30f933 + 3038abf commit 1ebca6f
Show file tree
Hide file tree
Showing 8 changed files with 472 additions and 524 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "2.0.4"
version = "2.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -38,11 +38,11 @@ OptimizationReverseDiffExt = "ReverseDiff"
OptimizationZygoteExt = "Zygote"

[compat]
ADTypes = "1.5"
ADTypes = "1.9"
ArrayInterface = "7.6"
DifferentiationInterface = "0.5"
DifferentiationInterface = "0.6.1"
DocStringExtensions = "0.9"
Enzyme = "0.12.12"
Enzyme = "0.13.2"
FastClosures = "0.3"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26"
Expand Down
14 changes: 8 additions & 6 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end
function hv_f2_alloc(x, f, p)
dx = Enzyme.make_zero(x)
Enzyme.autodiff_deferred(Enzyme.Reverse,
firstapply,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(x, dx),
Expand All @@ -58,7 +58,8 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
end

function cons_f2(x, dx, fcons, p, num_cons, i)
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx),
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(num_cons), Const(i))
return nothing
end
Expand All @@ -71,7 +72,7 @@ end

function cons_f2_oop(x, dx, fcons, p, i)
Enzyme.autodiff_deferred(
Enzyme.Reverse, inner_cons_oop, Active, Enzyme.Duplicated(x, dx),
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
Const(fcons), Const(p), Const(i))
return nothing
end
Expand All @@ -83,7 +84,8 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
end

function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
Enzyme.autodiff_deferred(
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
return nothing
end
Expand Down Expand Up @@ -187,7 +189,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if hv == true && f.hv === nothing
function hv!(H, θ, v, p = p)
H .= Enzyme.autodiff(
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
Const(f.f), Const(p)
)[1]
end
Expand Down Expand Up @@ -531,7 +533,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
for i in eachindex(Jaccache)
Enzyme.make_zero!(Jaccache[i])
end
y, Jaccache = Enzyme.autodiff(Enzyme.Forward, f.cons, Duplicated,
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
BatchDuplicated(θ, seeds), Const(p))
if size(y, 1) == 1
return reduce(vcat, Jaccache)
Expand Down
Loading

0 comments on commit 1ebca6f

Please sign in to comment.