Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Jun 21, 2024
1 parent 0a5e3f4 commit 34f7a53
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions test/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ using TensorOperations: IndexError
# test different versions of in-place methods,
# with changing element type and with nontrivial strides

@testset "tensorcopy!" begin
@testset "tensorcopy! with backend $b" for b in (TensorOperations.StridedNative(),
TensorOperations.BaseView(),
TensorOperations.BaseCopy())
Abig = randn(Float64, (30, 30, 30, 30))
A = view(Abig, 1 .+ 3 * (0:9), 2 .+ 2 * (0:6), 5 .+ 4 * (0:6), 4 .+ 3 * (0:8))
p = (3, 1, 4, 2)
Expand All @@ -142,15 +144,18 @@ using TensorOperations: IndexError
Acopy = tensorcopy(A, 1:4)
Ccopy = tensorcopy(C, 1:4)
pA = (p, ())
tensorcopy!(C, A, pA, false)
tensorcopy!(Ccopy, Acopy, pA, false)
@test C Ccopy
@test_throws IndexError tensorcopy!(C, A, ((1, 2, 3), ()), false)
@test_throws DimensionMismatch tensorcopy!(C, A, ((1, 2, 3, 4), ()), false)
@test_throws IndexError tensorcopy!(C, A, ((1, 2, 2, 3), ()), false)
α = randn(Float64)
tensorcopy!(C, A, pA, false, α, b)
tensorcopy!(Ccopy, Acopy, pA, false, 1.0, b)
@test C α * Ccopy
@test_throws IndexError tensorcopy!(C, A, ((1, 2, 3), ()), false, 1.0, b)
@test_throws DimensionMismatch tensorcopy!(C, A, ((1, 2, 3, 4), ()), false, 1.0, b)
@test_throws IndexError tensorcopy!(C, A, ((1, 2, 2, 3), ()), false, 1.0, b)
end

@testset "tensoradd!" begin
@testset "tensoradd! with backend $b" for b in (TensorOperations.StridedNative(),
TensorOperations.BaseView(),
TensorOperations.BaseCopy())
Abig = randn(Float64, (30, 30, 30, 30))
A = view(Abig, 1 .+ 3 * (0:9), 2 .+ 2 * (0:6), 5 .+ 4 * (0:6), 4 .+ 3 * (0:8))
p = (3, 1, 4, 2)
Expand All @@ -160,15 +165,18 @@ using TensorOperations: IndexError
Ccopy = tensorcopy(1:4, C, 1:4)
α = randn(Float64)
β = randn(Float64)
tensoradd!(C, A, (p, ()), false, α, β)
tensoradd!(C, A, (p, ()), false, α, β, b)
Ccopy = β * Ccopy + α * Acopy
@test C Ccopy
@test_throws IndexError tensoradd!(C, A, ((1, 2, 3), ()), false, 1.2, 0.5)
@test_throws DimensionMismatch tensoradd!(C, A, ((1, 2, 3, 4), ()), false, 1.2, 0.5)
@test_throws IndexError tensoradd!(C, A, ((1, 1, 2, 3), ()), false, 1.2, 0.5)
@test_throws IndexError tensoradd!(C, A, ((1, 2, 3), ()), false, 1.2, 0.5, b)
@test_throws DimensionMismatch tensoradd!(C, A, ((1, 2, 3, 4), ()), false, 1.2, 0.5,
b)
@test_throws IndexError tensoradd!(C, A, ((1, 1, 2, 3), ()), false, 1.2, 0.5, b)
end

@testset "tensortrace!" begin
@testset "tensortrace! with backend $b" for b in (TensorOperations.StridedNative(),
TensorOperations.BaseView(),
TensorOperations.BaseCopy())
Abig = rand(Float64, (30, 30, 30, 30))
A = view(Abig, 1 .+ 3 * (0:8), 2 .+ 2 * (0:14), 5 .+ 4 * (0:6), 7 .+ 2 * (0:8))
Bbig = rand(ComplexF64, (50, 50))
Expand All @@ -177,22 +185,25 @@ using TensorOperations: IndexError
Bcopy = tensorcopy(B, 1:2)
α = randn(Float64)
β = randn(Float64)
tensortrace!(B, A, ((2, 3), ()), ((1,), (4,)), false, α, β)
tensortrace!(B, A, ((2, 3), ()), ((1,), (4,)), false, α, β, b)
Bcopy = β * Bcopy
for i in 1 .+ (0:8)
Bcopy += α * view(A, i, :, :, i)
end
@test B Bcopy
@test_throws IndexError tensortrace!(B, A, ((1,), ()), ((2,), (3,)), false, α, β)
@test_throws IndexError tensortrace!(B, A, ((1,), ()), ((2,), (3,)), false, α, β, b)
@test_throws DimensionMismatch tensortrace!(B, A, ((1, 4), ()), ((2,), (3,)), false,
α, β)
α, β, b)
@test_throws IndexError tensortrace!(B, A, ((1, 4), ()), ((1, 1), (4,)), false, α,
β)
β, b)
@test_throws DimensionMismatch tensortrace!(B, A, ((1, 4), ()), ((1,), (3,)), false,
α, β)
α, β, b)
end

@testset "tensorcontract!" begin
@testset "tensorcontract! with backend $b" for b in (TensorOperations.StridedNative(),
TensorOperations.StridedBLAS(),
TensorOperations.BaseView(),
TensorOperations.BaseCopy())
Abig = rand(Float64, (30, 30, 30, 30))
A = view(Abig, 1 .+ 3 * (0:8), 2 .+ 2 * (0:14), 5 .+ 4 * (0:6), 7 .+ 2 * (0:8))
Bbig = rand(ComplexF64, (50, 50, 50))
Expand All @@ -211,23 +222,23 @@ using TensorOperations: IndexError
end
end
tensorcontract!(C, A, ((4, 1), (2, 3)), false, B, ((3, 1), (2,)), true,
((1, 2, 3), ()), α, β)
((1, 2, 3), ()), α, β, b)
@test C Ccopy
@test_throws IndexError tensorcontract!(C,
A, ((4, 1), (2, 4)), false,
B, ((1, 3), (2,)), false,
((1, 2, 3), ()), α, β)
((1, 2, 3), ()), α, β, b)
@test_throws IndexError tensorcontract!(C,
A, ((4, 1), (2, 3)), false,
B, ((1, 3), ()), false,
((1, 2, 3), ()), α, β)
((1, 2, 3), ()), α, β, b)
@test_throws IndexError tensorcontract!(C,
A, ((4, 1), (2, 3)), false,
B, ((1, 3), (2,)), false,
((1, 2), ()), α, β)
((1, 2), ()), α, β, b)
@test_throws DimensionMismatch tensorcontract!(C,
A, ((4, 1), (2, 3)), false,
B, ((1, 3), (2,)), false,
((1, 3, 2), ()), α, β)
((1, 3, 2), ()), α, β, b)
end
end

0 comments on commit 34f7a53

Please sign in to comment.