Skip to content

Commit

Permalink
Merge pull request #79 from ACEsuit/chho/fix_releaseinput
Browse files Browse the repository at this point in the history
Fix release input in `PooledSparseProduct` and tuple inputs
  • Loading branch information
cortner authored Dec 22, 2023
2 parents 6ca71ce + 20f14f5 commit 449fd0d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
19 changes: 16 additions & 3 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -574,15 +574,28 @@ import LuxCore: AbstractExplicitLayer, initialparameters, initialstates

struct PooledSparseProductLayer{NB} <: AbstractExplicitLayer
basis::PooledSparseProduct{NB}
meta::Dict{String, Any}
release_input::Bool
end

lux(basis::PooledSparseProduct) = PooledSparseProductLayer(basis)
function lux(basis::PooledSparseProduct;
name = String(nameof(typeof(basis))),
meta = Dict{String, Any}("name" => name),
release_input = true)
@assert haskey(meta, "name")
return PooledSparseProductLayer(basis, meta, release_input)
end

initialparameters(rng::AbstractRNG, layer::PooledSparseProductLayer) =
NamedTuple()

initialstates(rng::AbstractRNG, layer::PooledSparseProductLayer) =
NamedTuple()

(l::PooledSparseProductLayer)(BB, ps, st) =
evaluate(l.basis, BB), st
(l::PooledSparseProductLayer)(BB, ps, st) = begin
out = evaluate(l.basis, BB)
if l.release_input
release!.(BB)
end
return out, st
end
2 changes: 1 addition & 1 deletion src/lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function evaluate(l::PolyLuxLayer, X, ps, st)
out = acquire!(st.pool, _outsym(X), _out_size(l.basis, X), _valtype(l.basis, X))
evaluate!(out, l.basis, X, ps)
if l.release_input
release!(X)
release!.(X)
end
return out, st
end
Expand Down
3 changes: 2 additions & 1 deletion test/test_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
print_tf(@test fdtest(F, dF, 0.0; verbose=false))
end
end

println()

@info("Testing evaluate")
for ntest = 1:30
x = randn(in_size)
Expand Down

0 comments on commit 449fd0d

Please sign in to comment.