Skip to content

Commit

Permalink
Merge pull request #68 from ACEsuit/fix_linear_pb
Browse files Browse the repository at this point in the history
fix LinearLayer vector input
  • Loading branch information
cortner authored Sep 23, 2023
2 parents 2ab7e7e + 5cadefc commit c9b2265
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ end
LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer) = ( W = randn(rng, l.out_dim, l.in_dim), )
LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer) = ( l.use_cache ? (pool = ArrayPool(FlexArrayCache), ) : (pool = ArrayPool(FlexArray), ))

# TODO: check whether we can do this without multiple dispatch on vec/mat without loss of performance
function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractVector, ps, st)
val = l(x, ps, st)
function pb(A)
return NoTangent(), NoTangent(), ps.W' * A[1], (W = A[1] * x',), NoTangent()
end
return val, pb
end

function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractMatrix, ps, st)
val = l(x, ps, st)
function pb(A)
Expand Down
17 changes: 17 additions & 0 deletions test/test_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
print_tf(@test Y1 Y2 Y3)
end
println()

@info("Testing rrule for vector input")
for ntest = 1:30
local x, val, u
x = randn(in_d)
bu = randn(in_d)
_BB(t) = x + t * bu
val, _ = l(x, ps, st)
u = randn(size(val))
F(t) = dot(u, l(_BB(t), ps, st)[1])
dF(t) = begin
val, pb = Zygote.pullback(LuxCore.apply, l, _BB(t), ps, st)
∂BB = pb((u, st))[2]
return dot(∂BB, bu)
end
print_tf(@test fdtest(F, dF, 0.0; verbose=false))
end
end

@info("Testing evaluate")
Expand Down

0 comments on commit c9b2265

Please sign in to comment.