diff --git a/src/autodiff/gradfunc.jl b/src/autodiff/gradfunc.jl
index 94bbb91..0dc1b62 100644
--- a/src/autodiff/gradfunc.jl
+++ b/src/autodiff/gradfunc.jl
@@ -48,6 +48,41 @@ end
     (~protectf(g).f)(args...; kwargs...)
 end
 
+"""
+    gradient(Val(iloss), f, args::Tuple; kwargs...)
+    gradient(f, args; iloss::Int, kwargs...)
+
+Calculate the gradient of `f(args...; kwargs...)` for a reversible function `f` with respect to
+the input `args`. The integer value `iloss` specifies the position of the loss in `args`.
+
+!!! note
+    `iloss=1` is specially optimized, so putting the loss as the first argument avoids potential overhead.
+
+# Examples
+
+```jldoctest; setup=:(using NiLang)
+X = rand(2, 2)
+grads = NiLang.AD.gradient(Val(1), i_norm2, (0.0, X))
+grads[2] ≈ 2 .* X
+
+# output
+true
+```
+
+Note that the `gradient` calculation is disabled for containers with integer elements:
+
+```jldoctest; setup=:(using NiLang)
+X = rand(Int, 2, 2)
+NiLang.AD.gradient(Val(1), i_sum, (0, X))[2]
+
+# output
+2×2 Matrix{Int64}:
+ 0  0
+ 0  0
+```
+"""
+gradient
+
 @generated function gradient(::Val{iloss}, f, args::NTuple{N,Any}; kwargs...) where {iloss,N}
     newres = gensym()
     newargs = Any[:(GVar($newres[$i])) for i=1:N]
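
The docstring above also declares a keyword-argument form, `gradient(f, args; iloss::Int, kwargs...)`, which the jldoctests do not exercise. Below is a minimal usage sketch, assuming that form is equivalent to wrapping `iloss` in `Val` and calling the generated method shown at the end of the hunk:

```julia
using NiLang, NiLang.AD

X = rand(2, 2)

# Val form, as exercised by the jldoctests in the docstring.
g_val = NiLang.AD.gradient(Val(1), i_norm2, (0.0, X))

# Keyword form from the second documented signature; assumed here to
# dispatch to the same computation as the Val(iloss) form.
g_kw = NiLang.AD.gradient(i_norm2, (0.0, X); iloss=1)

# Both are expected to return the gradient 2X with respect to the input matrix.
g_val[2] ≈ g_kw[2] ≈ 2 .* X
```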