Skip to content

Commit

Permalink
new inv
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Apr 18, 2021
1 parent e2376d3 commit 6e83d8d
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/autodiff/gradfunc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ end
newargs[iloss] = :(GVar($newres[$iloss], one($newres[$iloss])))
quote
$newres = f(args...; kwargs...)
grad.(NiLangCore.wrap_tuple((~f)($(newargs...); kwargs...)))
grad((~f)($(newargs...); kwargs...))
end
end

Expand Down
13 changes: 13 additions & 0 deletions src/autodiff/instructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ end
@nograd (identity)(a!::Real, b::GVar)
@nograd (identity)(a!::GVar, b::Real)

# inv
@eval @i @inline function (inv)(out!::GVar{T}, y::GVar) where T
out!.x -= inv(y.x)
@routine @invcheckoff begin
@zeros T a1
a1 += y.x ^ 2
end
y.g -= out!.g / a1
~@routine
end
@nograd (inv)(a!::Real, b::GVar)
@nograd (inv)(a!::GVar, b::Real)

# +- (triple)
@i @inline function (+)(out!::GVar, x::GVar, y::GVar)
out!.x -= x.x + y.x
Expand Down
10 changes: 6 additions & 4 deletions src/autodiff/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export jacobian, jacobian_repeat

wrap_tuple(x, args) = length(args) == 1 ? (x,) : x

"""
jacobian_repeat(f, args...; iin::Int, iout::Int=iin, kwargs...)
Expand All @@ -10,11 +12,11 @@ function jacobian_repeat(f, args...; iin::Int, iout::Int=iin, kwargs...)
_check_input(args, iin, iout)
N = length(args[iout])
res = zeros(eltype(args[iin]), length(args[iin]), N)
xargs = NiLangCore.wrap_tuple(f(args...; kwargs...))
xargs = wrap_tuple(f(args...; kwargs...), args)
for i = 1:N
gxargs = GVar.(xargs)
@inbounds gxargs[iout][i] = GVar(value(gxargs[iout][i]), one(eltype(xargs[iout])))
@inbounds res[:,i] .= vec(grad.(NiLangCore.wrap_tuple((~f)(gxargs...; kwargs...))[iin]))
@inbounds res[:,i] .= vec(grad.(wrap_tuple((~f)(gxargs...; kwargs...), gxargs)[iin]))
end
return res
end
Expand All @@ -30,10 +32,10 @@ One can use key word arguments `iin` and `iout` to specify the input and output
"""
function jacobian(f, args...; iin::Int, iout::Int=iin, kwargs...)
_check_input(args, iin, iout)
args = NiLangCore.wrap_tuple(f(args...; kwargs...))
args = wrap_tuple(f(args...; kwargs...), args)
ABT = AutoBcast{eltype(args[iout]), length(args[iout])}
_args = map(i-> i==iout ? wrap_jacobian(ABT, args[i]) : wrap_bcastgrad(ABT, args[i]), 1:length(args))
_args = NiLangCore.wrap_tuple((~f)(_args...; kwargs...))
_args = wrap_tuple((~f)(_args...; kwargs...), args)
out = zeros(eltype(args[iin]), length(args[iin]), length(args[iout]))
for i=1:length(args[iin])
@inbounds out[i,:] .= grad(_args[iin][i]).x
Expand Down
9 changes: 9 additions & 0 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ end
~@routine
end

@i @inline function :(+=)(inv)(y!::Complex{T}, b::Complex{T}) where T
@routine @invcheckoff begin
b2 zero(T)
b2 += abs2(b)
end
y! += b' / b2
~@routine
end

@i @inline function (exp)(y!::Complex{T}, x::Complex{T}) where T
@routine @invcheckoff begin
@zeros T s c expn
Expand Down
3 changes: 2 additions & 1 deletion test/autodiff/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ end
(opm(identity), (x,y)), (opm(+), (x, y, z)),
(opm(-), (x, y, z)), (opm(*), (x, y, z)),
(opm(/), (x, y, z)), (opm(^), (x, y, r)),
(opm(exp), (x, y)), (opm(log), (x, y))
(opm(exp), (x, y)), (opm(log), (x, y)),
(opm(inv), (x, y))
]
@test ccheck_grad(subop, args; verbose=true, iloss=1)
r1 = subop(args...)
Expand Down
1 change: 1 addition & 0 deletions test/autodiff/instructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Test
@test check_grad(opm(-), (1.0, 2.0, 2.0); verbose=true, iloss=1)
@test check_grad(opm(^), (1.0, 2.0, 2); verbose=true, iloss=1)
@test check_grad(opm(^), (1.0, 2.0, 2.0); verbose=true, iloss=1)
@test check_grad(opm(inv), (1.0, 2.0); verbose=true, iloss=1)
@test check_grad(opm(sqrt), (1.0, 2.0); verbose=true, iloss=1)
@test check_grad(opm(abs), (1.0, -2.0); verbose=true, iloss=1)
@test check_grad(opm(abs2), (1.0, -2.0); verbose=true, iloss=1)
Expand Down
2 changes: 0 additions & 2 deletions test/instructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ end
@test x === Int32(3)
@instr DEC(x)
@test x === Int32(2)
@instr (x,) |> INC |> DEC |> DEC
@test x === Int32(1)
end

@testset "HADAMARD" begin
Expand Down

0 comments on commit 6e83d8d

Please sign in to comment.