forked from modern-fortran/neural-fortran
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_reshape_layer.f90
53 lines (42 loc) · 1.51 KB
/
test_reshape_layer.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
program test_reshape_layer
  ! Checks that a two-layer network (1-d input followed by a reshape layer)
  ! yields a rank-3 output whose shape is output_shape and whose elements
  ! equal the intrinsic reshape of the input vector.
  !
  ! Every failed check is reported on stderr. If any check failed, the
  ! program terminates with stop code 1; otherwise it prints a success line.

  use iso_fortran_env, only: stderr => error_unit
  ! The layer constructor is imported under the local name reshape_layer so
  ! that the name "reshape" below still refers to the intrinsic function.
  use nf, only: input, network, reshape_layer => reshape
  ! NOTE(review): these nf_datasets names are not referenced in this program;
  ! the import is kept unchanged so the program's dependencies stay the same.
  use nf_datasets, only: download_and_unpack, keras_reshape_url

  implicit none

  type(network) :: net
  real, allocatable :: sample_input(:), output(:,:,:)
  integer, parameter :: output_shape(3) = [3, 32, 32]
  ! The flat input must hold exactly as many elements as the reshaped output.
  integer, parameter :: input_size = product(output_shape)
  logical :: ok

  ok = .true.

  ! Create the network
  net = network([ &
    input(input_size), &
    reshape_layer(output_shape) &
  ])

  if (size(net % layers) /= 2) then
    write(stderr, '(a)') 'the network should have 2 layers.. failed'
    ok = .false.
  end if

  ! Initialize test data
  allocate(sample_input(input_size))
  call random_number(sample_input)

  ! Propagate forward and get the output
  call net % forward(sample_input)
  call net % layers(2) % get_output(output)

  ! The checks are chained on purpose: shape() requires an allocated array,
  ! and the element-wise comparison requires both operands to have the same
  ! shape. Each later check therefore runs only if the earlier ones passed.
  if (.not. allocated(output)) then
    write(stderr, '(a)') 'the reshape layer produces an allocated output.. failed'
    ok = .false.
  else if (any(shape(output) /= output_shape)) then
    write(stderr, '(a)') 'the reshape layer produces expected output shape.. failed'
    ok = .false.
  else if (.not. all(reshape(sample_input, output_shape) == output)) then
    ! Exact equality is intended here: the expected output is a plain
    ! rearrangement of the input values, with no arithmetic applied.
    write(stderr, '(a)') 'the reshape layer produces expected output values.. failed'
    ok = .false.
  end if

  if (ok) then
    print '(a)', 'test_reshape_layer: All tests passed.'
  else
    write(stderr, '(a)') 'test_reshape_layer: One or more tests failed.'
    stop 1
  end if

end program test_reshape_layer