Skip to content

Commit

Permalink
Fixed tolerances in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ivan-pi committed May 5, 2020
1 parent 0da4d57 commit ca439e6
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions src/tests/linalg/test_linalg.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
program test_linalg

use stdlib_experimental_error, only: check
use stdlib_experimental_kinds, only: sp, dp, qp, int8, int16, int32, int64
use stdlib_experimental_linalg, only: diag, eye, trace

implicit none

real(sp), parameter :: sptol = 1000 * epsilon(1._sp)
real(dp), parameter :: dptol = 1000 * epsilon(1._dp)
real(qp), parameter :: qptol = 1000 * epsilon(1._qp)

logical :: warn

! whether calls to check issue a warning
! or stop execution
warn = .false.

!
Expand Down Expand Up @@ -59,12 +69,12 @@ subroutine test_eye
msg="all(eye(5) == diag([(1,i=1,5)] failed.",warn=warn)

rye = eye(6)
call check(sum(rye - diag([(1.0_sp,i=1,6)])) < epsilon(rye), &
msg="sum(rye - diag([(1.0_sp,i=1,6)])) < epsilon(rye) failed.",warn=warn)
call check(sum(rye - diag([(1.0_sp,i=1,6)])) < sptol, &
msg="sum(rye - diag([(1.0_sp,i=1,6)])) < sptol failed.",warn=warn)

cye = eye(7)
call check(abs(trace(cye) - complex(7.0_sp,0.0_sp)) < epsilon(1.0_sp), &
msg="abs(trace(cye) - complex(7.0_sp,0.0_sp)) < epsilon(1.0_sp) failed.",warn=warn)
call check(abs(trace(cye) - complex(7.0_sp,0.0_sp)) < sptol, &
msg="abs(trace(cye) - complex(7.0_sp,0.0_sp)) < sptol failed.",warn=warn)
end subroutine

subroutine test_diag_rsp
Expand Down Expand Up @@ -95,8 +105,8 @@ subroutine test_diag_rsp_k
call check(all(a == b), &
msg="all(a == b) failed.",warn=warn)

call check(sum(diag(a,-1)) - (n-1) < epsilon(1.0_sp), &
msg="sum(diag(a,-1)) - (n-1) < epsilon(1.0_sp) failed.",warn=warn)
call check(sum(diag(a,-1)) - (n-1) < sptol, &
msg="sum(diag(a,-1)) - (n-1) < sptol failed.",warn=warn)

call check(all(a == transpose(diag([(1._sp,i=1,n-1)],1))), &
msg="all(a == transpose(diag([(1._sp,i=1,n-1)],1))) failed",warn=warn)
Expand Down Expand Up @@ -151,10 +161,10 @@ subroutine test_diag_csp
call check(all(a == b), &
msg="all(a == b) failed.",warn=warn)

call check(all(abs(real(diag(a)) - [(i,i=1,n)]) < epsilon(1.0_sp)), &
msg="all(abs(real(diag(a)) - [(i,i=1,n)]) < epsilon(1.0_sp))", warn=warn)
call check(all(abs(aimag(diag(a)) - [(1,i=1,n)]) < epsilon(1.0_sp)), &
msg="all(abs(aimag(diag(a)) - [(1,i=1,n)]) < epsilon(1.0_sp))", warn=warn)
call check(all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol), &
msg="all(abs(real(diag(a)) - [(i,i=1,n)]) < sptol)", warn=warn)
call check(all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol), &
msg="all(abs(aimag(diag(a)) - [(1,i=1,n)]) < sptol)", warn=warn)
end subroutine

subroutine test_diag_cdp
Expand Down Expand Up @@ -204,7 +214,6 @@ subroutine test_diag_int16
msg="all(diag(a) == pack(a,mask))", warn=warn)
call check(all(diag(diag(a)) == merge(a,0_int16,mask)), &
msg="all(diag(diag(a)) == merge(a,0_int16,mask)) failed.", warn=warn)
a = unpack(int([1,2,3,4],int16),eye(n)==1,a)
end subroutine
subroutine test_diag_int32
integer, parameter :: n = 3
Expand Down Expand Up @@ -261,8 +270,8 @@ subroutine test_trace_rsp
integer :: i
write(*,*) "test_trace_rsp"
a = reshape([(i,i=1,n**2)],[n,n])
call check(abs(trace(a) - sum(diag(a))) < epsilon(1.0_sp), &
msg="abs(trace(a) - sum(diag(a))) < epsilon(1.0_sp) failed.",warn=warn)
call check(abs(trace(a) - sum(diag(a))) < sptol, &
msg="abs(trace(a) - sum(diag(a))) < sptol failed.",warn=warn)
end subroutine

subroutine test_trace_rsp_nonsquare
Expand All @@ -278,8 +287,8 @@ subroutine test_trace_rsp_nonsquare
a = reshape([(i,i=1,n*(n+1))],[n,n+1])
ans = sum([1._sp,6._sp,11._sp,16._sp])

call check(abs(trace(a) - ans) < epsilon(1.0_sp), &
msg="abs(trace(a) - ans) < epsilon(1.0_sp) failed.",warn=warn)
call check(abs(trace(a) - ans) < sptol, &
msg="abs(trace(a) - ans) < sptol failed.",warn=warn)
end subroutine

subroutine test_trace_rdp
Expand All @@ -288,8 +297,8 @@ subroutine test_trace_rdp
integer :: i
write(*,*) "test_trace_rdp"
a = reshape([(i,i=1,n**2)],[n,n])
call check(abs(trace(a) - sum(diag(a))) < epsilon(1.0_dp), &
msg="abs(trace(a) - sum(diag(a))) < epsilon(1.0_dp) failed.",warn=warn)
call check(abs(trace(a) - sum(diag(a))) < dptol, &
msg="abs(trace(a) - sum(diag(a))) < dptol failed.",warn=warn)
end subroutine

subroutine test_trace_rdp_nonsquare
Expand All @@ -305,8 +314,8 @@ subroutine test_trace_rdp_nonsquare
a = reshape([(i**2,i=1,n*(n-1))],[n,n-1])
ans = sum([1._dp,36._dp,121._dp])

call check(abs(trace(a) - ans) < epsilon(1.0_dp), &
msg="abs(trace(a) - ans) < epsilon(1.0_sp) failed.",warn=warn)
call check(abs(trace(a) - ans) < dptol, &
msg="abs(trace(a) - ans) < dptol failed.",warn=warn)
end subroutine

subroutine test_trace_rqp
Expand All @@ -315,8 +324,8 @@ subroutine test_trace_rqp
integer :: i
write(*,*) "test_trace_rqp"
a = reshape([(i,i=1,n**2)],[n,n])
call check(abs(trace(a) - sum(diag(a))) < epsilon(1.0_qp), &
msg="abs(trace(a) - sum(diag(a))) < epsilon(1.0_qp) failed.",warn=warn)
call check(abs(trace(a) - sum(diag(a))) < qptol, &
msg="abs(trace(a) - sum(diag(a))) < qptol failed.",warn=warn)
end subroutine


Expand All @@ -336,8 +345,8 @@ subroutine test_trace_csp
b = re + im*i_

! tr(A + B) = tr(A) + tr(B)
call check(abs(trace(a+b) - (trace(a) + trace(b))) < 10*epsilon(1.0_sp), &
msg="abs(trace(a+b) - (trace(a) + trace(b))) < 10*epsilon(1.0_sp) failed.",warn=warn)
call check(abs(trace(a+b) - (trace(a) + trace(b))) < sptol, &
msg="abs(trace(a+b) - (trace(a) + trace(b))) < sptol failed.",warn=warn)
end subroutine

subroutine test_trace_cdp
Expand All @@ -350,8 +359,8 @@ subroutine test_trace_cdp
a = reshape([(j + (n**2 - (j-1))*i_,j=1,n**2)],[n,n])
ans = complex(15,15) !(1 + 5 + 9) + (9 + 5 + 1)i

call check(abs(trace(a) - ans) < epsilon(1.0_dp), &
msg="abs(trace(a) - ans) < epsilon(1.0_dp) failed.",warn=warn)
call check(abs(trace(a) - ans) < dptol, &
msg="abs(trace(a) - ans) < dptol failed.",warn=warn)
end subroutine

subroutine test_trace_cqp
Expand All @@ -360,8 +369,8 @@ subroutine test_trace_cqp
complex(qp), parameter :: i_ = complex(0,1)
write(*,*) "test_trace_cqp"
a = 3*eye(n) + 4*eye(n)*i_ ! pythagorean triple
call check(abs(trace(a)) - 3*5.0_qp < epsilon(1.0_qp), &
msg="abs(trace(a)) - 3*5.0_qp < epsilon(1.0_qp) failed.",warn=warn)
call check(abs(trace(a)) - 3*5.0_qp < qptol, &
msg="abs(trace(a)) - 3*5.0_qp < qptol failed.",warn=warn)
end subroutine


Expand Down

0 comments on commit ca439e6

Please sign in to comment.