Skip to content

Commit 11c3e5a

Browse files
committed
Add a few tests for a dense network
1 parent 6ccb92a commit 11c3e5a

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ string(REGEX REPLACE "^ | $" "" LIBS "${LIBS}")
9999

100100
# tests
101101
enable_testing()
102-
foreach(execid dense_layer input1d_layer)
102+
foreach(execid input1d_layer dense_layer dense_network)
103103
add_executable(test_${execid} test/test_${execid}.f90)
104104
target_link_libraries(test_${execid} neural ${LIBS})
105105
add_test(test_${execid} bin/test_${execid})

test/test_dense_network.f90

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
program test_dense_network
2+
use iso_fortran_env, only: stderr => error_unit
3+
use nf, only: dense, input, network
4+
implicit none
5+
type(network) :: net
6+
logical :: ok = .true.
7+
8+
! Minimal 2-layer network
9+
net = network([ &
10+
input(1), &
11+
dense(1) &
12+
])
13+
14+
if (.not. size(net % layers) == 2) then
15+
write(stderr, '(a)') 'dense network should have 2 layers.. failed'
16+
ok = .false.
17+
end if
18+
19+
if (.not. all(net % output([0.]) == 0.5)) then
20+
write(stderr, '(a)') &
21+
'dense network should output exactly 0.5 for input 0.. failed'
22+
ok = .false.
23+
end if
24+
25+
training: block
26+
real :: x(1), y(1)
27+
real :: tolerance = 1e-3
28+
integer :: n
29+
integer, parameter :: num_iterations = 1000
30+
31+
x = [0.123]
32+
y = [0.765]
33+
34+
do n = 1, num_iterations
35+
call net % forward(x)
36+
call net % backward(y)
37+
call net % update(1.)
38+
if (all(abs(net % output(x) - y) < tolerance)) exit
39+
end do
40+
41+
if (.not. n <= num_iterations) then
42+
write(stderr, '(a)') &
43+
'dense network should converge in simple training.. failed'
44+
ok = .false.
45+
end if
46+
47+
end block training
48+
49+
! A bit larger multi-layer network
50+
net = network([ &
51+
input(784), &
52+
dense(30), &
53+
dense(20), &
54+
dense(10) &
55+
])
56+
57+
if (.not. size(net % layers) == 4) then
58+
write(stderr, '(a)') 'dense network should have 4 layers.. failed'
59+
ok = .false.
60+
end if
61+
62+
if (ok) then
63+
print '(a)', 'test_dense_network: All tests passed.'
64+
else
65+
write(stderr, '(a)') 'test_dense_network: One or more tests failed.'
66+
stop 1
67+
end if
68+
69+
end program test_dense_network

0 commit comments

Comments
 (0)