diff --git a/CMakeLists.txt b/CMakeLists.txt index 42847f3e..75d29ff6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,7 +99,7 @@ string(REGEX REPLACE "^ | $" "" LIBS "${LIBS}") # tests enable_testing() -foreach(execid dense_layer input1d_layer) +foreach(execid input1d_layer dense_layer dense_network) add_executable(test_${execid} test/test_${execid}.f90) target_link_libraries(test_${execid} neural ${LIBS}) add_test(test_${execid} bin/test_${execid}) diff --git a/test/test_dense_network.f90 b/test/test_dense_network.f90 new file mode 100644 index 00000000..9df7e71b --- /dev/null +++ b/test/test_dense_network.f90 @@ -0,0 +1,69 @@ +program test_dense_network + use iso_fortran_env, only: stderr => error_unit + use nf, only: dense, input, network + implicit none + type(network) :: net + logical :: ok = .true. + + ! Minimal 2-layer network + net = network([ & + input(1), & + dense(1) & + ]) + + if (.not. size(net % layers) == 2) then + write(stderr, '(a)') 'dense network should have 2 layers.. failed' + ok = .false. + end if + + if (.not. all(net % output([0.]) == 0.5)) then + write(stderr, '(a)') & + 'dense network should output exactly 0.5 for input 0.. failed' + ok = .false. + end if + + training: block + real :: x(1), y(1) + real :: tolerance = 1e-3 + integer :: n + integer, parameter :: num_iterations = 1000 + + x = [0.123] + y = [0.765] + + do n = 1, num_iterations + call net % forward(x) + call net % backward(y) + call net % update(1.) + if (all(abs(net % output(x) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'dense network should converge in simple training.. failed' + ok = .false. + end if + + end block training + + ! A bit larger multi-layer network + net = network([ & + input(784), & + dense(30), & + dense(20), & + dense(10) & + ]) + + if (.not. size(net % layers) == 4) then + write(stderr, '(a)') 'dense network should have 4 layers.. failed' + ok = .false. + end if + + if (ok) then + print '(a)', 'test_dense_network: All tests passed.' + else + write(stderr, '(a)') 'test_dense_network: One or more tests failed.' + stop 1 + end if + +end program test_dense_network