Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi arg fwd gradient #1952

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 221 additions & 101 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1794,16 +1794,28 @@ end
# Concatenate two tuples (or other splattable collections) into a single flat tuple.
@inline function tupleconcat(x, y)
    return (x..., y...)
end

# Concatenate three or more collections by merging the first two and folding the
# remaining ones in, left to right; element order is the order of the arguments.
@inline function tupleconcat(x, y, z...)
    return tupleconcat((x..., y...), z...)
end

# No chunk size requested: the shadows are `onehot(x)`, wrapped in a one-element tuple
# (one entry per differentiated argument).
create_shadows(::Nothing, x) = (onehot(x),)

# Chunk size of exactly one: same result as the un-chunked method, i.e. `onehot(x)`
# wrapped in a one-element tuple.
create_shadows(::Val{1}, x) = (onehot(x),)

function create_shadows(::Val{chunk}, x) where {chunk}
return (chunkedonehot(x, Val(chunk)),)
# Build one shadow entry per argument (`x` followed by each element of `vargs`) and
# return them as a tuple, in argument order:
#   * arguments whose type is `<: Enzyme.Const` get `nothing`;
#   * with `chunk === nothing` or `chunk::Val{1}` the entry is `onehot(arg)`;
#   * otherwise the entry is `chunkedonehot(arg, chunk)`.
# This is a generated function: in the generator body `vargs` is the tuple of argument
# *types*, while inside the quoted expressions `vargs[i]` refers to the runtime values.
@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N}
    argexprs = Union{Symbol,Expr}[:x]
    argtypes = Type[X]
    for idx in 1:N
        push!(argexprs, :(vargs[$idx]))
        push!(argtypes, vargs[idx])
    end

    # The chunking mode depends only on the chunk type, so decide it once up front.
    unchunked = ChunkTy == Nothing || ChunkTy == Val{1}

    shadowexprs = Union{Symbol,Expr}[]
    for (argexpr, argtype) in zip(argexprs, argtypes)
        shadow = if argtype <: Enzyme.Const
            :(nothing)
        elseif unchunked
            :(onehot($argexpr))
        else
            :(chunkedonehot($argexpr, chunk))
        end
        push!(shadowexprs, shadow)
    end

    return quote
        Base.@_inline_meta
        ($(shadowexprs...),)
    end
end

struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N}
Expand Down Expand Up @@ -1890,7 +1902,7 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1))
(derivs = ([3.0, 2.0],), val = 6.0)
```

For functions which return an AbstractArray or scalar, this function will return an AbstracttArray
For functions which return an AbstractArray or scalar, this function will return an AbstractArray
whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made
about the type of the AbstractArray returned by this function (which may or may not be the same
as the input AbstractArray if provided).
Expand All @@ -1905,119 +1917,227 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0])
# output
([3.0 2.0 0.0; 0.0 1.0 1.0],)
```

This function supports multiple arguments and computes the gradient with respect to each argument.

```jldoctest gradfwd2
mul(x, y) = x[1]*y[2] + x[2]*y[1]

gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1])

# output

([3.1, 2.7], [3.0, 2.0])
```

This includes the ability to mark an argument as `Const` if its derivative is not needed, in which case `nothing` is returned in the corresponding position of the result.

```jldoctest gradfwd2
gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))

# output

([3.1, 2.7], nothing)
```

```jldoctest gradfwd2
gradient(Forward, mul, Const([2.0, 3.0]), [2.7, 3.1])

# output

(nothing, [3.0, 2.0])
```
"""
@inline function gradient(
@generated function gradient(
fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity},
f,
x;
f::F,
x::ty_0,
args::Vararg{Any,N};
chunk::CS = nothing,
shadows = create_shadows(chunk, x),
) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS}
if length(shadows[1]) == 0
return if ReturnPrimal
(; derivs = (x,), val = f(x.val))
else
(x,)
end
end
if chunk == Val(0)
throw(ErrorException("Cannot differentiate with a batch size of 0"))
shadows::ST = create_shadows(chunk, x, args...),
) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST, ty_0, N}

syms = Union{Symbol,Expr}[:x]
shads = Union{Symbol,Expr}[:(shadows[1])]
tys = Type[ty_0]
for i in 1:N
push!(syms, :(args[$i]))
push!(tys, args[i])
push!(shads, :(shadows[1+$i]))
end
fval = if F <: Annotation
:(f.val)
else
:f
end

gradtup = if chunk == nothing
resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1]))
vals = Expr[]
consts = Expr[]
for arg in syms
push!(vals, :($arg isa Const ? $arg.val : $arg))
push!(consts, :($arg isa Const ? $arg : Const($arg)))
end

res = values(resp[1])
dres = if x isa AbstractFloat
res[1]
else
res
if CS == Val{0}
return quote
Base.@_inline_meta
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
if ReturnPrimal
((dres,), resp[2])
else
(dres,)
end

exprs = Expr[]
primal = nothing
derivatives = Union{Symbol,Expr}[]

primmode = :(fm)
for (i, (arg, ty)) in enumerate(zip(syms, tys))
if ty <: Const
push!(derivatives, :(nothing))
continue
end
elseif chunk == Val(1)
if ReturnPrimal
rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1]))
dres1 = rp[1]
fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=#

res = ntuple(length(shadows[1]) - 1) do i
autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1]
argnum = length(ST.parameters[i].parameters)

argderivative = if argnum == 0
vals[i]
elseif CS == Nothing
dargs = Expr[]
for (j, arg2) in enumerate(syms)
if i == j
push!(dargs, :(BatchDuplicated($arg, $(shads[i]))))
else
push!(dargs, consts[j])
end
end
gres = if x isa AbstractFloat
dres1[1]
else
(dres1, res...)

df = :f
if F <: Enzyme.Duplicated
zeros = Expr[]
for i in 1:argnum
push!(zeros, :(f.dval))
end
df = :(BatchDuplicated(f.val, ($(zeros...),) ))
end
((gres,), rp[2])
else
res = ntuple(length(shadows[1])) do i
autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1]

resp = Symbol("resp_$i")
push!(exprs, quote
$resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...))
end)
if ReturnPrimal && primal == nothing
primal = :($resp[2])
primmode = NoPrimal(fm())
end
(if x isa AbstractFloat
res[1]
else
res
end,)
end
else
if ReturnPrimal
rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1]))
dres1 = values(rp[1])
gres = if x isa AbstractFloat
dres1[1]

deriv = if ty <: AbstractFloat
:($resp[1][1])
else
fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=#
tmp = ntuple(length(shadows[1]) - 1) do i
values(
autodiff(
fm2,
f,
BatchDuplicated,
BatchDuplicated(x, shadows[1][i+1]),
)[1],
)
:(values($resp[1]))
end
deriv
elseif CS == Val{1}
subderivatives = Expr[]
for an in 1:argnum
dargs = Expr[]
for (j, arg2) in enumerate(syms)
if i == j
push!(dargs, :(Duplicated($arg, $(shads[i])[$an])))
else
push!(dargs, consts[j])
end
end

resp = Symbol("resp_$i"*"_"*string(an))
push!(exprs, quote
$resp = autodiff($primmode, f, Duplicated, $(dargs...))
end)
if ReturnPrimal && primal == nothing
primal = :($resp[2])
primmode = NoPrimal(fm())
end

deriv = if ty <: AbstractFloat
:($resp[1][1])
else
:($resp[1])
end
tupleconcat(dres1, tmp...)

push!(subderivatives, deriv)
end
((gres,), rp[2])
:(($(subderivatives...),))
else
tmp = ntuple(length(shadows[1])) do i
values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1])
subderivatives = Expr[]
for an in 1:argnum
dargs = Expr[]
for (j, arg2) in enumerate(syms)
if i == j
push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an])))
else
push!(dargs, consts[j])
end
end

resp = Symbol("resp_$i"*"_"*string(an))
push!(exprs, quote
$resp = autodiff($primmode, f, BatchDuplicated, $(dargs...))
end)
if ReturnPrimal && primal == nothing
primal = :($resp[2])
primmode = NoPrimal(fm())
end

deriv = if ty <: AbstractFloat
:($resp[1][1])
else
:(values($resp[1]))
end

push!(subderivatives, deriv)
end
res = tupleconcat(tmp...)
(if x isa AbstractFloat
res[1]
:(tupleconcat($(subderivatives...)))
end

deriv = if ty <: AbstractFloat
argderivative
else
tmp = Symbol("tmp_$i")
push!(exprs, :($tmp = $argderivative))
if ty <: AbstractArray
if argnum > 0
quote
if $tmp[1] isa AbstractArray
inshape = size($(vals[1]))
outshape = size($tmp[1])
# st : outshape x total inputs
tupstack($tmp, outshape, inshape)
else
TupleArray($tmp, size($arg))
end
end
else
:(TupleArray($tmp, size($arg)))
end
else
res
end,)
tmp
end
end
push!(derivatives, deriv)
end

cols = if ReturnPrimal
gradtup[1][1]
else
gradtup[1]
end
res = if x isa AbstractFloat
cols
elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray
inshape = size(x)
outshape = size(cols[1])
# st : outshape x total inputs
tupstack(cols, outshape, inshape)
elseif x isa AbstractArray
TupleArray(cols, size(x))
else
cols
# We weirdly asked for no derivatives
if ReturnPrimal && primal == nothing
primal = :($fval($(vals...)))
end
if ReturnPrimal
(; derivs = (res,), val = gradtup[2])

result = if ReturnPrimal
:((; derivs = ($(derivatives...),), val = $primal))
else
(res,)
:(($(derivatives...),))
end

return quote
Base.@_inline_meta
$(exprs...)
$result
end
end

Expand Down
Loading
Loading