Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ExpandingMan committed Sep 27, 2024
1 parent 20a9dba commit 10d2246
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ module EnzymeStaticArraysExt
using StaticArrays
using Enzyme

#TODO: it would be better if we could always return SArray from gradient directly,
# to retain shape information. for now we are at least able to convert
@inline function Base.convert(::Type{SArray}, tpa::Enzyme.TupleArray{T,S,L,N}) where {T,S,L,N}
SArray{Tuple{S...},T,N,L}(tpa.data)
end
@inline Base.convert(::Type{StaticArray}, tpa::Enzyme.TupleArray) = convert(SArray, tpa)

@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape)
reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...))
end
Expand Down
45 changes: 45 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2854,6 +2854,51 @@ end
@test dx[1] 0
@test dx[2] 30
@test dx[3] 0

f0 = x -> sum(2*x)
f1 = x -> @SVector Float64[x[2], 2*x[2]]
f2 = x -> @SMatrix Float64[x[2] x[1]; 2*x[2] 2*x[1]]

x = @SVector Float64[1, 2]

dx = gradient(Forward, f0, x)[1]
@test dx isa Enzyme.TupleArray
@test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works
@test gradient(Forward, f1, x)[1] isa SMatrix
@test gradient(Forward, f1, x)[1] == [0 1.0; 0 2.0]
@test jacobian(Forward, f2, x)[1] isa SArray
@test jacobian(Forward, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2))

x = @SMatrix Float64[1 2; 3 4]

dx = gradient(Forward, f0, x)[1]
@test dx isa Enzyme.TupleArray
@test convert(SArray, dx) == fill(2.0, (2,2))
@test gradient(Forward, f1, x)[1] isa SArray
@test gradient(Forward, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2))
@test jacobian(Forward, f2, x)[1] isa SArray
@test jacobian(Forward, f2, x)[1] == reshape(
Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2),
)

x = @SVector Float64[1, 2]

dx = gradient(Reverse, f0, x)[1]
@test dx isa SVector
@test convert(SArray, dx) == [2.0, 2.0] # test to make sure conversion works
@test_broken gradient(Reverse, f1, x)[1] isa SMatrix
@test_broken gradient(Reverse, f1, x)[1] == [0 1.0; 0 2.0]
@test_broken jacobian(Reverse, f2, x)[1] isa SArray
@test_broken jacobian(Reverse, f2, x)[1] == reshape(Float64[0,0,1,2,1,2,0,0], (2,2,2))

x = @SMatrix Float64[1 2; 3 4]

@test_broken gradient(Reverse, f1, x)[1] isa SArray
@test_broken gradient(Reverse, f1, x)[1] == reshape(Float64[0,0,1,2,0,0,0,0], (2,2,2))
@test_broken jacobian(Reverse, f2, x)[1] isa SArray
@test_broken jacobian(Reverse, f2, x)[1] == reshape(
Float64[0,0,1,2,1,2,0,0,0,0,0,0,0,0,0,0], (2,2,2,2),
)
end

function unstable_fun(A0)
Expand Down

0 comments on commit 10d2246

Please sign in to comment.