Skip to content

Commit

Permalink
Avoid closure in _f
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jul 13, 2023
1 parent 162075c commit 8a9f7b6
Showing 1 changed file with 34 additions and 39 deletions.
73 changes: 34 additions & 39 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,34 @@ isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)
function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p,
num_cons = 0)
_f = (θ, y, args...) -> (y .= first(f.f(θ, p, args...)); return nothing)
_f = (f, θ, args...) -> first(f(θ, p, args...))

Check warning on line 12 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L12

Added line #L12 was not covered by tests

if f.grad === nothing
function grad(res, θ, args...)
res .= zero(eltype(res))
Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res),
Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...)
Enzyme.autodiff(Enzyme.Reverse, _f, f.f, Enzyme.Duplicated(θ, res), args...)

Check warning on line 17 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L16-L17

Added lines #L16 - L17 were not covered by tests
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
end

if f.hess === nothing
function g(θ, bθ, y, by, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, bθ), Enzyme.DuplicatedNoNeed(y, by), args...)
function g(θ, bθ, _f, f, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse, _f, f, Enzyme.Duplicated(θ, bθ), args...)
return nothing

Check warning on line 26 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L24-L26

Added lines #L24 - L26 were not covered by tests
end
function hess(res, θ, args...)
y = Vector{Float64}(undef, 1)

vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
by = ones(1)
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(y),
Const(by),
Const(_f),
f.f,
args...)

for i in eachindex(θ)
Expand All @@ -52,16 +48,18 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
end

if f.hv === nothing
function f2(x, v, _f, f, args...)::Float64
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,

Check warning on line 53 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L51-L53

Added lines #L51 - L53 were not covered by tests
f,
Enzyme.Duplicated(x, dx),
args...)
Float64(dot(dx, v))

Check warning on line 57 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L57

Added line #L57 was not covered by tests
end
hv = function (H, θ, v, args...)
function f2(x, v, args...)::Float64
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,
Enzyme.Duplicated(x, dx),
Enzyme.DuplicatedNoNeed(zeros(1), ones(1)),
args...)
Float64(dot(dx, v))
end
H .= Enzyme.gradient(Enzyme.Forward, x -> f2(x, v, args...), θ)
global f, _f
H .= zero(eltype(H))
Enzyme.autodiff(Enzyme.Forward, f2, Duplicated(θ, H), v, _f, f.f, args...)

Check warning on line 62 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L60-L62

Added lines #L60 - L62 were not covered by tests
end
else
hv = f.hv
Expand Down Expand Up @@ -119,39 +117,35 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
cache::Optimization.ReInitCache,
adtype::AutoEnzyme,
num_cons = 0)
_f = (θ, y, args...) -> (y .= first(f.f(θ, cache.p, args...)); return nothing)
_f = (f, θ, args...) -> first(f(θ, cache.p, args...))

Check warning on line 120 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L120

Added line #L120 was not covered by tests

if f.grad === nothing
function grad(res, θ, args...)
res .= zero(eltype(res))
Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res),
Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...)
Enzyme.autodiff(Enzyme.Reverse, _f, f.f, Enzyme.Duplicated(θ, res), args...)

Check warning on line 125 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)

Check warning on line 128 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L128

Added line #L128 was not covered by tests
end

if f.hess === nothing
function g(θ, bθ, y, by, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, bθ),
Enzyme.DuplicatedNoNeed(y, by), args...)
function g(θ, bθ, _f, f, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse, _f, f, Enzyme.Duplicated(θ, bθ),

Check warning on line 133 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L132-L133

Added lines #L132 - L133 were not covered by tests
args...)
return nothing

Check warning on line 135 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L135

Added line #L135 was not covered by tests
end
function hess(res, θ, args...)
y = Vector{Float64}(undef, 1)

vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
by = ones(1)
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(y),
Const(by),
Const(_f),
f.f,
args...)

for i in eachindex(θ)
Expand All @@ -163,16 +157,17 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
end

if f.hv === nothing
function f2(x, v, _f, f, args...)::Float64
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,

Check warning on line 162 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L160-L162

Added lines #L160 - L162 were not covered by tests
f,
Enzyme.Duplicated(x, dx),
args...)
Float64(dot(dx, v))

Check warning on line 166 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L166

Added line #L166 was not covered by tests
end
hv = function (H, θ, v, args...)
function f2(x, v, args...)::Float64
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse, _f,
Enzyme.Duplicated(x, dx),
Enzyme.DuplicatedNoNeed(zeros(1), ones(1)),
args...)
Float64(dot(dx, v))
end
H .= Enzyme.gradient(Enzyme.Forward, x -> f2(x, v, args...), θ)
H .= zero(eltype(H))
Enzyme.autodiff(Enzyme.Forward, f2, Duplicated(θ, H), v, Const(_f), f.f, args...)

Check warning on line 170 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L169-L170

Added lines #L169 - L170 were not covered by tests
end
else
hv = f.hv
Expand Down

0 comments on commit 8a9f7b6

Please sign in to comment.