Skip to content

Commit

Permalink
Use an older version of DI.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Sep 11, 2024
1 parent ad56716 commit bf85245
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[compat]
ADTypes = "1.2.1"
DifferentiationInterface = "0.6.0"
DifferentiationInterface = "0.5.17"
ForwardDiff = "0.9.0, 0.10.0"
NLPModels = "0.18, 0.19, 0.20, 0.21"
Requires = "1"
Expand Down
21 changes: 14 additions & 7 deletions src/di.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ for (ADGradient, fbackend) in ((:EnzymeADGradient, :AutoEnzyme),
end

function gradient(b::$ADGradient, f, x)
g = DifferentiationInterface.gradient(f, b.extras, b.backend, x)
g = DifferentiationInterface.gradient(f, b.backend, x, b.extras)
# g = DifferentiationInterface.gradient(f, b.extras, b.backend, x)
return g
end

function gradient!(b::$ADGradient, g, f, x)
DifferentiationInterface.gradient!(f, g, b.extras, b.backend, x)
DifferentiationInterface.gradient!(f, g, b.backend, x, b.extras)
# DifferentiationInterface.gradient!(f, g, b.extras, b.backend, x)
return g
end

Expand Down Expand Up @@ -57,7 +59,8 @@ for (ADJprod, fbackend) in ((:EnzymeADJprod, :AutoEnzyme),
end

function Jprod!(b::$ADJprod, Jv, f, x, v, ::Val)
DifferentiationInterface.pushforward!(f, Jv, b.extras, b.backend, x, v)
DifferentiationInterface.pushforward!(f, Jv, b.backend, x, v, b.extras)
# DifferentiationInterface.pushforward!(f, Jv, b.extras, b.backend, x, v)
return Jv
end

Expand Down Expand Up @@ -88,7 +91,8 @@ for (ADJtprod, fbackend) in ((:EnzymeADJtprod, :AutoEnzyme),
end

function Jtprod!(b::$ADJtprod, Jtv, f, x, v, ::Val)
DifferentiationInterface.pullback!(f, Jtv, b.extras, b.backend, x, v)
DifferentiationInterface.pullback!(f, Jtv, b.backend, x, v, b.extras)
# DifferentiationInterface.pullback!(f, Jtv, b.extras, b.backend, x, v)
return Jtv
end

Expand Down Expand Up @@ -119,7 +123,9 @@ for (ADJacobian, fbackend) in ((:EnzymeADJacobian, :AutoEnzyme),
end

function jacobian(b::$ADJacobian, f, x)
return DifferentiationInterface.jacobian(f, b.extras, b.backend, x)
J = DifferentiationInterface.jacobian(f, b.backend, x, b.extras)
# J = DifferentiationInterface.jacobian(f, b.extras, b.backend, x)
return J
end

end
Expand Down Expand Up @@ -150,14 +156,15 @@ end
# end
#
# function Hessian(b::$ADHessian, f, x)
# return DifferentiationInterface.hessian(f, b.extras, b.backend, x)
# H = DifferentiationInterface.hessian(f, b.extras, b.backend, x)
# return H
# end
#
# end
# end

# for (ADHvprod, fbackend) in ((:EnzymeADHvprod, :AutoEnzyme),
# (:ZygoteADHvprod, :AutoZygote))
# (:ZygoteADHvprod, :AutoZygote))
# @eval begin
#
# struct $ADHvprod{B, E} <: ADBackend
Expand Down

0 comments on commit bf85245

Please sign in to comment.