Commit

Multi arg fwd gradient
wsmoses committed Oct 9, 2024
1 parent 4d4c546 commit 9ed8992
Showing 3 changed files with 786 additions and 556 deletions.
327 changes: 230 additions & 97 deletions src/Enzyme.jl
@@ -1794,16 +1794,29 @@ end
@inline tupleconcat(x, y) = (x..., y...)
@inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...)
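# For example, `tupleconcat` splices any number of tuples into one flat tuple:
#   tupleconcat((1.0, 2.0), (3.0,), (4.0, 5.0)) == (1.0, 2.0, 3.0, 4.0, 5.0)
# The chunked code paths below rely on this to stitch per-chunk derivative pieces back together.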

function create_shadows(::Nothing, x)
return (onehot(x),)
end

function create_shadows(::Val{1}, x)
return (onehot(x),)
end

function create_shadows(::Val{chunk}, x) where {chunk}
return (chunkedonehot(x, Val(chunk)),)
@generated function create_shadows(chunk::ChunkTy, x, args::Vararg{Any,N}) where {ChunkTy, N}
if N == 0
if ChunkTy == Nothing
return quote
Base.@_inline_meta
(onehot(x),)
end
elseif ChunkTy == Val{1}
return quote
Base.@_inline_meta
(onehot(x),)
end
else
return quote
Base.@_inline_meta
(chunkedonehot(x, chunk),)
end
end
else
return quote
error("Unsupported create_shadows of multiple arguments")
end
end
end
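# Rough sketch of what this generated function returns for a single vector argument,
# assuming `onehot` produces one standard-basis seed per entry:
#   create_shadows(nothing, [2.0, 3.0])  # -> (([1.0, 0.0], [0.0, 1.0]),)
#   create_shadows(Val(2), [2.0, 3.0])   # -> (chunkedonehot([2.0, 3.0], Val(2)),), one chunk of two seeds
# The seed tuple is what `gradient` below wraps in `BatchDuplicated` (or `Duplicated` when `chunk=Val(1)`).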

struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N}
@@ -1890,7 +1903,7 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
(derivs = ([3.0, 2.0],), val = 6.0)
```
For functions which return an AbstractArray or scalar, this function will return an AbstractArray
whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made
about the type of the AbstractArray returned by this function (which may or may not be the same
as the input AbstractArray if provided).
@@ -1905,119 +1918,239 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
# output
([3.0 2.0 0.0; 0.0 1.0 1.0],)
```
This function supports multiple arguments and computes the gradient with respect to each.
```jldoctest gradfwd2
mul(x, y) = x[1]*y[2] + x[2]*y[1]
gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1])
# output
([3.1, 2.7], [3.0, 2.0])
```
This includes the ability to mark some arguments as `Const` if their derivatives are not needed; the corresponding entry of the returned derivative tuple is then `nothing`.
```jldoctest gradfwd2
gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
# output
([3.1, 2.7], nothing)
```
```jldoctest gradfwd2
gradient(Forward, mul, Const([2.0, 3.0]), [2.7, 3.1])
# output
(nothing, [3.0, 2.0])
```
"""
@inline function gradient(
@generated function gradient(
fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity},
f,
x;
f::F,
x::ty_0,
args::Vararg{Any,N};
chunk::CS = nothing,
shadows = create_shadows(chunk, x),
) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS}
shadows::ST = create_shadows(chunk, x, args...),
) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST}

syms = Union{Symbol,Expr}[:x]
shads = Union{Symbol,Expr}[:(shadows[1])]
tys = Type[ty_0]
for i in 1:N
push!(syms, :(args[$i]))
push!(tys, args[i])
push!(shads, :(shadows[1+$i]))
end
fval = if F <: Annotation
:(f.val)
else
:f
end

vals = Expr[]
consts = Expr[]
for arg in syms
push!(vals, :($arg isa Const ? $arg.val : $arg))
push!(consts, :($arg isa Const ? $arg : Const($arg)))
end
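# At this point there are three parallel expression lists, one entry per positional argument:
#   syms[i]   -- how the generated code refers to the i-th argument (`x`, `args[1]`, ...)
#   vals[i]   -- unwraps a `Const` annotation to its underlying value
#   consts[i] -- wraps a plain value as `Const` so inactive slots stay constant in each sweep
# These get spliced into the `autodiff` calls built below.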

if length(shadows[1]) == 0
derivs = Expr[]
for arg in syms
push!(derivs, :($arg isa Const ? nothing : $arg))
end
return if ReturnPrimal
(; derivs = (x,), val = f(x.val))
return quote
Base.@_inline_meta
(; derivs = ($(derivs...),), val = $fval($(vals...)))
end
else
(x,)
return quote
Base.@_inline_meta
($(derivs...),)
end
end
end
if chunk == Val(0)
throw(ErrorException("Cannot differentiate with a batch size of 0"))
if CS == Val{0}
return quote
Base.@_inline_meta
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
end

gradtup = if chunk == nothing
resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1]))
exprs = Expr[]
primal = nothing
derivatives = Expr[]

res = values(resp[1])
dres = if x isa AbstractFloat
res[1]
else
res
end
if ReturnPrimal
((dres,), resp[2])
else
(dres,)
primmode = :(fm)
for (i, (arg, ty)) in enumerate(zip(syms, tys))
if ty <: Const
push!(derivatives, :(nothing))
continue
end
elseif chunk == Val(1)
if ReturnPrimal
rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1]))
dres1 = rp[1]
fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=#

res = ntuple(length(shadows[1]) - 1) do i
autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1]
argnum = length(ST.parameters[i])

argderivative = if CS == Nothing
dargs = Expr[]
for (j, arg2) in enumerate(syms)
if i == j
push!(dargs, :(BatchDuplicated($arg, $(shads[i]))))
else
push!(dargs, consts[j])
end
end
gres = if x isa AbstractFloat
dres1[1]
else
(dres1, res...)

df = :f
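# If the function itself is `Duplicated` (it closes over differentiable state), the block
# below replicates its shadow `f.dval` once per batch lane so that the resulting
# `BatchDuplicated` callee matches the width of the argument seeds.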
if F <: Enzyme.Duplicated
zeros = Expr[]
for i in 1:argnum
push!(zeros, :(f.dval))
end
df = :(BatchDuplicated(f.val, ($(zeros...),) ))
end
((gres,), rp[2])
else
res = ntuple(length(shadows[1])) do i
autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1]

resp = Symbol("resp_$i")
push!(exprs, quote
$resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...))
end)
if ReturnPrimal && primal == nothing
primal = :($resp[2])
primmode = NoPrimal(fm())
end
(if x isa AbstractFloat
res[1]
else
res
end,)
end
else
if ReturnPrimal
rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1]))
dres1 = values(rp[1])
gres = if x isa AbstractFloat
dres1[1]

deriv = if ty <: AbstractFloat
:($resp[1][1])
else
fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=#
tmp = ntuple(length(shadows[1]) - 1) do i
values(
autodiff(
fm2,
f,
BatchDuplicated,
BatchDuplicated(x, shadows[1][i+1]),
)[1],
)
:(values($resp[1]))
end
deriv
elseif CS == Val{1}
subderivatives = Expr[]
for an in 1:argnum
dargs = Expr[]
for (j, arg2) in enumerate(syms)
if i == j
push!(dargs, :(Duplicated($arg, $(shads[i])[$an])))
else
push!(dargs, consts[j])
end
end
tupleconcat(dres1, tmp...)

resp = Symbol("resp_$(i)_$(an)")
push!(exprs, quote
$resp = autodiff($primmode, f, Duplicated, $(dargs...))
end)
if ReturnPrimal && primal == nothing
primal = :($resp[2])
primmode = NoPrimal(fm())
end

deriv = if ty <: AbstractFloat
:($resp[1][1])
else
:($resp[1])
end

push!(subderivatives, deriv)
end
((gres,), rp[2])
:(($(subderivatives...),))
else
tmp = ntuple(length(shadows[1])) do i
values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1])
subderivatives = Expr[]
for an in 1:argnum
dargs = Expr[]
for (j, arg2) in enumerate(syms)
if i == j
push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an])))
else
push!(dargs, consts[j])
end
end

resp = Symbol("resp_$(i)_$(an)")
push!(exprs, quote
$resp = autodiff($primmode, f, BatchDuplicated, $(dargs...))
end)
if ReturnPrimal && primal == nothing
primal = :($resp[2])
primmode = NoPrimal(fm())
end

deriv = if ty <: AbstractFloat
:($resp[1][1])
else
:($resp[1])
end

push!(subderivatives, deriv)
end
res = tupleconcat(tmp...)
(if x isa AbstractFloat
res[1]
:(tupleconcat($(subderivatives...)))
end

deriv = if ty <: AbstractFloat
argderivative
else
tmp = Symbol("tmp_$i")
push!(exprs, :($tmp = $argderivative))
if ty <: AbstractArray
if argnum > 0
inshape = :(size($(consts[1])))
outshape = :(size($tmp[1]))
# st : outshape x total inputs
quote
if $tmp[1] isa AbstractArray
tupstack($tmp, $outshape, $inshape)
else
TupleArray($tmp, size($arg))
end
end
else
:(TupleArray($tmp, size($arg)))
end
else
tmp
end
end
push!(derivatives, deriv)
end

cols = if ReturnPrimal
gradtup[1][1]
else
gradtup[1]
end
res = if x isa AbstractFloat
cols
elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray
inshape = size(x)
outshape = size(cols[1])
# st : outshape x total inputs
tupstack(cols, outshape, inshape)
elseif x isa AbstractArray
TupleArray(cols, size(x))
else
cols
# We weirdly asked for no derivatives
if ReturnPrimal && primal == nothing
primal = :($fval($(vals...)))
end
if ReturnPrimal
(; derivs = (res,), val = gradtup[2])

result = if ReturnPrimal
:((; derivs = ($(derivatives...),), val = $primal))
else
(res,)
:(($(derivatives...),))
end

return quote
Base.@_inline_meta
$(exprs...)
$result
end
end
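# Hedged sketch of the idea (not a literal expansion): for `gradient(Forward, mul, x, y)`
# with both arguments active and no chunking, the generated body performs one forward
# sweep per argument, holding every other argument `Const`, roughly
#   resp_1 = autodiff(fm, mul, BatchDuplicated, BatchDuplicated(x, shadows[1]), Const(y))
#   resp_2 = autodiff(fm, mul, BatchDuplicated, Const(x), BatchDuplicated(y, shadows[2]))
# and then reshapes `values(resp_1[1])` and `values(resp_2[1])` into the per-argument
# derivative slots of the returned tuple.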

