diff --git a/CMakeLists.txt b/CMakeLists.txt
index 15a294fd..4d40fd57 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -74,6 +74,8 @@ add_library(neural
   src/nf/nf_datasets_mnist_submodule.f90
   src/nf/nf_dense_layer.f90
   src/nf/nf_dense_layer_submodule.f90
+  src/nf/nf_flatten_layer.f90
+  src/nf/nf_flatten_layer_submodule.f90
   src/nf/nf_input1d_layer.f90
   src/nf/nf_input1d_layer_submodule.f90
   src/nf/nf_input3d_layer.f90
@@ -102,13 +104,13 @@ string(REGEX REPLACE "^ | $" "" LIBS "${LIBS}")
 # tests
 enable_testing()
 
-foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer maxpool2d_layer dense_network conv2d_network)
+foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer maxpool2d_layer flatten_layer dense_network conv2d_network)
   add_executable(test_${execid} test/test_${execid}.f90)
   target_link_libraries(test_${execid} neural ${LIBS})
   add_test(test_${execid} bin/test_${execid})
 endforeach()
 
-foreach(execid mnist simple sine)
+foreach(execid cnn mnist simple sine)
   add_executable(${execid} example/${execid}.f90)
   target_link_libraries(${execid} neural ${LIBS})
 endforeach()
diff --git a/README.md b/README.md
index ae13a2c0..b1bd7205 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 | Dense (fully-connected) | `dense` | `input` (1-d) | 1 | ✅ | ✅ |
 | Convolutional (2-d) | `conv2d` | `input` (3-d), `conv2d`, `maxpool2d` | 3 | ✅ | ❌ |
 | Max-pooling (2-d) | `maxpool2d` | `input` (3-d), `conv2d`, `maxpool2d` | 3 | ✅ | ❌ |
+| Flatten | `flatten` | `input` (3-d), `conv2d`, `maxpool2d` | 1 | ✅ | ✅ |
 
 ## Getting started
 
@@ -172,9 +173,13 @@ to run the tests.
 The easiest way to get a sense of how to use neural-fortran is to look at
 examples, in increasing level of complexity:
 
-1. [simple](example/simple.f90): Approximating a simple, constant data relationship
+1. [simple](example/simple.f90): Approximating a simple, constant data
+   relationship
 2. [sine](example/sine.f90): Approximating a sine function
-3. [mnist](example/mnist.f90): Hand-written digit recognition using the MNIST dataset
+3. [mnist](example/mnist.f90): Hand-written digit recognition using the MNIST
+   dataset
+4. [cnn](example/cnn.f90): Creating a simple CNN and running a forward pass
+   using `input`, `conv2d`, `maxpool2d`, `flatten`, and `dense` layers
 
 The examples also show you the extent of the public API that's meant to be used
 in applications, i.e. anything from the `nf` module.
diff --git a/example/cnn.f90 b/example/cnn.f90
new file mode 100644
index 00000000..03c92b03
--- /dev/null
+++ b/example/cnn.f90
@@ -0,0 +1,31 @@
+program cnn
+
+  use nf, only: conv2d, dense, flatten, input, maxpool2d, network
+
+  implicit none
+  type(network) :: net
+  real, allocatable :: x(:,:,:)
+
+  print '("Creating a CNN and doing a forward pass")'
+  print '("(backward pass not implemented yet)")'
+  print '(60("="))'
+
+  net = network([ &
+    input([3, 32, 32]), &
+    conv2d(filters=16, kernel_size=3, activation='relu'), & ! (16, 30, 30)
+    maxpool2d(pool_size=2), & ! (16, 15, 15)
+    conv2d(filters=32, kernel_size=3, activation='relu'), & ! (32, 13, 13)
+    maxpool2d(pool_size=2), & ! (32, 6, 6)
+    flatten(), &
+    dense(10) &
+  ])
+
+  ! Print a network summary to the screen
+  call net % print_info()
+
+  allocate(x(3,32,32))
+  call random_number(x)
+
+  print *, 'Output:', net % output(x)
+
+end program cnn
diff --git a/src/nf.f90 b/src/nf.f90
index 7271bd16..34de190b 100644
--- a/src/nf.f90
+++ b/src/nf.f90
@@ -2,7 +2,7 @@ module nf
   !! User API: everything an application needs to reference directly
   use nf_datasets_mnist, only: label_digits, load_mnist
   use nf_layer, only: layer
-  use nf_layer_constructors, only: conv2d, dense, input, maxpool2d
+  use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d
   use nf_network, only: network
   use nf_optimizers, only: sgd
 end module nf
diff --git a/src/nf/nf_flatten_layer.f90 b/src/nf/nf_flatten_layer.f90
new file mode 100644
index 00000000..38e38098
--- /dev/null
+++ b/src/nf/nf_flatten_layer.f90
@@ -0,0 +1,75 @@
+module nf_flatten_layer
+
+  !! This module provides the concrete flatten layer type.
+  !! It is used internally by the layer type.
+  !! It is not intended to be used directly by the user.
+
+  use nf_base_layer, only: base_layer
+
+  implicit none
+
+  private
+  public :: flatten_layer
+
+  type, extends(base_layer) :: flatten_layer
+
+    !! Concrete implementation of a flatten (3-d to 1-d) layer.
+
+    integer, allocatable :: input_shape(:)
+    integer :: output_size
+
+    real, allocatable :: gradient(:,:,:)
+    real, allocatable :: output(:)
+
+  contains
+
+    procedure :: backward
+    procedure :: forward
+    procedure :: init
+
+  end type flatten_layer
+
+  interface flatten_layer
+    elemental module function flatten_layer_cons() result(res)
+      !! This function returns a `flatten_layer` instance.
+      type(flatten_layer) :: res
+        !! `flatten_layer` instance
+    end function flatten_layer_cons
+  end interface flatten_layer
+
+  interface
+
+    pure module subroutine backward(self, input, gradient)
+      !! Apply the backward pass to the flatten layer.
+      !! This is a reshape operation from the 1-d gradient to the 3-d input.
+      class(flatten_layer), intent(in out) :: self
+        !! Flatten layer instance
+      real, intent(in) :: input(:,:,:)
+        !! Input from the previous layer
+      real, intent(in) :: gradient(:)
+        !! Gradient from the next layer
+    end subroutine backward
+
+    pure module subroutine forward(self, input)
+      !! Propagate forward the layer.
+      !! Calling this subroutine updates the values of a few data components
+      !! of `flatten_layer` that are needed for the backward pass.
+      class(flatten_layer), intent(in out) :: self
+        !! Flatten layer instance
+      real, intent(in) :: input(:,:,:)
+        !! Input from the previous layer
+    end subroutine forward
+
+    module subroutine init(self, input_shape)
+      !! Initialize the layer data structures.
+      !!
+      !! This is a deferred procedure from the `base_layer` abstract type.
+      class(flatten_layer), intent(in out) :: self
+        !! Flatten layer instance
+      integer, intent(in) :: input_shape(:)
+        !! Shape of the input layer
+    end subroutine init
+
+  end interface
+
+end module nf_flatten_layer
diff --git a/src/nf/nf_flatten_layer_submodule.f90 b/src/nf/nf_flatten_layer_submodule.f90
new file mode 100644
index 00000000..d52e996d
--- /dev/null
+++ b/src/nf/nf_flatten_layer_submodule.f90
@@ -0,0 +1,48 @@
+submodule(nf_flatten_layer) nf_flatten_layer_submodule
+
+  !! This submodule provides the implementation of the flatten layer type.
+  !! It is used internally by the layer type.
+  !! It is not intended to be used directly by the user.
+
+  use nf_base_layer, only: base_layer
+
+  implicit none
+
+contains
+
+  elemental module function flatten_layer_cons() result(res)
+    type(flatten_layer) :: res
+  end function flatten_layer_cons
+
+
+  pure module subroutine backward(self, input, gradient)
+    class(flatten_layer), intent(in out) :: self
+    real, intent(in) :: input(:,:,:)
+    real, intent(in) :: gradient(:)
+    self % gradient = reshape(gradient, shape(input))
+  end subroutine backward
+
+
+  pure module subroutine forward(self, input)
+    class(flatten_layer), intent(in out) :: self
+    real, intent(in) :: input(:,:,:)
+    self % output = pack(input, .true.)
+  end subroutine forward
+
+
+  module subroutine init(self, input_shape)
+    class(flatten_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    self % input_shape = input_shape
+    self % output_size = product(input_shape)
+
+    allocate(self % gradient(input_shape(1), input_shape(2), input_shape(3)))
+    self % gradient = 0
+
+    allocate(self % output(self % output_size))
+    self % output = 0
+
+  end subroutine init
+
+end submodule nf_flatten_layer_submodule
diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90
index eaab6d52..0222a8c5 100644
--- a/src/nf/nf_layer_constructors.f90
+++ b/src/nf/nf_layer_constructors.f90
@@ -7,7 +7,7 @@ module nf_layer_constructors
   implicit none
 
   private
-  public :: conv2d, dense, input, maxpool2d
+  public :: conv2d, dense, flatten, input, maxpool2d
 
   interface input
 
@@ -84,6 +84,27 @@ pure module function dense(layer_size, activation) result(res)
         !! Resulting layer instance
     end function dense
 
+    pure module function flatten() result(res)
+      !! Flatten (3-d -> 1-d) layer constructor.
+      !!
+      !! Use this layer to chain layers with 3-d outputs to layers with 1-d
+      !! inputs. For example, to chain a `conv2d` or a `maxpool2d` layer
+      !! with a `dense` layer in a CNN for classification, place a `flatten`
+      !! layer between them.
+      !!
+      !! A flatten layer must not be the first layer in the network.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only: flatten, layer
+      !! type(layer) :: flatten_layer
+      !! flatten_layer = flatten()
+      !! ```
+      type(layer) :: res
+        !! Resulting layer instance
+    end function flatten
+
     pure module function conv2d(filters, kernel_size, activation) result(res)
       !! 2-d convolutional layer constructor.
       !!
diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90
index 7fd0637a..8e991901 100644
--- a/src/nf/nf_layer_constructors_submodule.f90
+++ b/src/nf/nf_layer_constructors_submodule.f90
@@ -3,6 +3,7 @@
   use nf_layer, only: layer
   use nf_conv2d_layer, only: conv2d_layer
   use nf_dense_layer, only: dense_layer
+  use nf_flatten_layer, only: flatten_layer
   use nf_input1d_layer, only: input1d_layer
   use nf_input3d_layer, only: input3d_layer
   use nf_maxpool2d_layer, only: maxpool2d_layer
@@ -11,26 +12,26 @@
 
 contains
 
-  pure module function input1d(layer_size) result(res)
-    integer, intent(in) :: layer_size
+  pure module function conv2d(filters, kernel_size, activation) result(res)
+    integer, intent(in) :: filters
+    integer, intent(in) :: kernel_size
+    character(*), intent(in), optional :: activation
     type(layer) :: res
-    res % name = 'input'
-    res % layer_shape = [layer_size]
-    res % input_layer_shape = [integer ::]
-    allocate(res % p, source=input1d_layer(layer_size))
-    res % initialized = .true.
-  end function input1d
+    res % name = 'conv2d'
 
-  pure module function input3d(layer_shape) result(res)
-    integer, intent(in) :: layer_shape(3)
-    type(layer) :: res
-    res % name = 'input'
-    res % layer_shape = layer_shape
-    res % input_layer_shape = [integer ::]
-    allocate(res % p, source=input3d_layer(layer_shape))
-    res % initialized = .true.
-  end function input3d
+    if (present(activation)) then
+      res % activation = activation
+    else
+      res % activation = 'sigmoid'
+    end if
+
+    allocate( &
+      res % p, &
+      source=conv2d_layer(filters, kernel_size, res % activation) &
+    )
+
+  end function conv2d
 
 
   pure module function dense(layer_size, activation) result(res)
@@ -52,27 +53,33 @@ pure module function dense(layer_size, activation) result(res)
   end function dense
 
 
-  pure module function conv2d(filters, kernel_size, activation) result(res)
-    integer, intent(in) :: filters
-    integer, intent(in) :: kernel_size
-    character(*), intent(in), optional :: activation
+  pure module function flatten() result(res)
     type(layer) :: res
+    res % name = 'flatten'
+    allocate(res % p, source=flatten_layer())
+  end function flatten
 
-    res % name = 'conv2d'
 
-    if (present(activation)) then
-      res % activation = activation
-    else
-      res % activation = 'sigmoid'
-    end if
-
-    allocate( &
-      res % p, &
-      source=conv2d_layer(filters, kernel_size, res % activation) &
-    )
+  pure module function input1d(layer_size) result(res)
+    integer, intent(in) :: layer_size
+    type(layer) :: res
+    res % name = 'input'
+    res % layer_shape = [layer_size]
+    res % input_layer_shape = [integer ::]
+    allocate(res % p, source=input1d_layer(layer_size))
+    res % initialized = .true.
+  end function input1d
 
-  end function conv2d
+  pure module function input3d(layer_shape) result(res)
+    integer, intent(in) :: layer_shape(3)
+    type(layer) :: res
+    res % name = 'input'
+    res % layer_shape = layer_shape
+    res % input_layer_shape = [integer ::]
+    allocate(res % p, source=input3d_layer(layer_shape))
+    res % initialized = .true.
+  end function input3d
 
   pure module function maxpool2d(pool_size, stride) result(res)
     integer, intent(in) :: pool_size
diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90
index fdbda9d2..dbffb6b1 100644
--- a/src/nf/nf_layer_submodule.f90
+++ b/src/nf/nf_layer_submodule.f90
@@ -2,39 +2,53 @@
 
   use nf_conv2d_layer, only: conv2d_layer
   use nf_dense_layer, only: dense_layer
+  use nf_flatten_layer, only: flatten_layer
   use nf_input1d_layer, only: input1d_layer
   use nf_input3d_layer, only: input3d_layer
   use nf_maxpool2d_layer, only: maxpool2d_layer
 
-  implicit none
-
 contains
 
   pure module subroutine backward(self, previous, gradient)
+    implicit none
    class(layer), intent(in out) :: self
    class(layer), intent(in) :: previous
    real, intent(in) :: gradient(:)
 
-    ! Backward pass currently implemented only for dense layers
-    select type(this_layer => self % p); type is(dense_layer)
-
-    ! Previous layer is the input layer to this layer.
-    ! For a backward pass on a dense layer, we must accept either an input layer
-    ! or another dense layer as input.
-    select type(prev_layer => previous % p)
+    ! Backward pass currently implemented only for dense and flatten layers
+    select type(this_layer => self % p)
 
-      type is(input1d_layer)
-        call this_layer % backward(prev_layer % output, gradient)
       type is(dense_layer)
-        call this_layer % backward(prev_layer % output, gradient)
-    end select
+
+        ! Upstream layers permitted: input1d, dense, flatten
+        select type(prev_layer => previous % p)
+          type is(input1d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+          type is(dense_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+          type is(flatten_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+        end select
+
+      type is(flatten_layer)
+
+        ! Upstream layers permitted: input3d, conv2d, maxpool2d
+        select type(prev_layer => previous % p)
+          type is(input3d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+          type is(conv2d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+          type is(maxpool2d_layer)
+            call this_layer % backward(prev_layer % output, gradient)
+        end select
+
     end select
 
   end subroutine backward
 
   pure module subroutine forward(self, input)
+    implicit none
     class(layer), intent(in out) :: self
     class(layer), intent(in) :: input
 
@@ -42,17 +56,19 @@ pure module subroutine forward(self, input)
 
       type is(dense_layer)
 
-        ! Input layers permitted: input1d, dense
+        ! Upstream layers permitted: input1d, dense, flatten
         select type(prev_layer => input % p)
           type is(input1d_layer)
             call this_layer % forward(prev_layer % output)
           type is(dense_layer)
             call this_layer % forward(prev_layer % output)
+          type is(flatten_layer)
+            call this_layer % forward(prev_layer % output)
         end select
 
       type is(conv2d_layer)
 
-        ! Input layers permitted: input3d, conv2d, maxpool2d
+        ! Upstream layers permitted: input3d, conv2d, maxpool2d
         select type(prev_layer => input % p)
           type is(input3d_layer)
             call this_layer % forward(prev_layer % output)
@@ -64,7 +80,19 @@
 
       type is(maxpool2d_layer)
 
-        ! Input layers permitted: input3d, conv2d, maxpool2d
+        ! Upstream layers permitted: input3d, conv2d, maxpool2d
+        select type(prev_layer => input % p)
+          type is(input3d_layer)
+            call this_layer % forward(prev_layer % output)
+          type is(conv2d_layer)
+            call this_layer % forward(prev_layer % output)
+          type is(maxpool2d_layer)
+            call this_layer % forward(prev_layer % output)
+        end select
+
+      type is(flatten_layer)
+
+        ! Upstream layers permitted: input3d, conv2d, maxpool2d
         select type(prev_layer => input % p)
           type is(input3d_layer)
             call this_layer % forward(prev_layer % output)
@@ -80,6 +108,7 @@
 
   end subroutine forward
 
   pure module subroutine get_output_1d(self, output)
+    implicit none
     class(layer), intent(in) :: self
     real, allocatable, intent(out) :: output(:)
 
@@ -89,8 +118,10 @@
         allocate(output, source=this_layer % output)
       type is(dense_layer)
         allocate(output, source=this_layer % output)
+      type is(flatten_layer)
+        allocate(output, source=this_layer % output)
       class default
-        error stop '1-d output can only be read from an input1d or dense layer.'
+        error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
 
     end select
 
@@ -98,6 +129,7 @@ end subroutine get_output_1d
 
   pure module subroutine get_output_3d(self, output)
+    implicit none
     class(layer), intent(in) :: self
     real, allocatable, intent(out) :: output(:,:,:)
 
@@ -118,6 +150,7 @@ end subroutine get_output_3d
 
   impure elemental module subroutine init(self, input)
+    implicit none
     class(layer), intent(in out) :: self
     class(layer), intent(in) :: input
 
@@ -128,13 +161,15 @@ impure elemental module subroutine init(self, input)
       call this_layer % init(input % layer_shape)
     end select
 
-    ! The shape of conv2d or maxpool2d layers is not known
+    ! The shape of conv2d, maxpool2d, or flatten layers is not known
     ! until we receive an input layer.
     select type(this_layer => self % p)
       type is(conv2d_layer)
         self % layer_shape = shape(this_layer % output)
       type is(maxpool2d_layer)
         self % layer_shape = shape(this_layer % output)
+      type is(flatten_layer)
+        self % layer_shape = shape(this_layer % output)
     end select
 
     self % input_layer_shape = input % layer_shape
 
@@ -144,6 +179,7 @@ end subroutine init
 
   impure elemental module subroutine print_info(self)
+    implicit none
     class(layer), intent(in) :: self
     print '("Layer: ", a)', self % name
     print '(60("-"))'
@@ -157,6 +193,7 @@ end subroutine print_info
 
   impure elemental module subroutine update(self, learning_rate)
+    implicit none
     class(layer), intent(in out) :: self
     real, intent(in) :: learning_rate
 
diff --git a/src/nf/nf_maxpool2d_layer_submodule.f90 b/src/nf/nf_maxpool2d_layer_submodule.f90
index 68aa0152..3105f447 100644
--- a/src/nf/nf_maxpool2d_layer_submodule.f90
+++ b/src/nf/nf_maxpool2d_layer_submodule.f90
@@ -43,15 +43,19 @@ pure module subroutine forward(self, input)
     integer :: i, j, n
     integer :: ii, jj
     integer :: iend, jend
+    integer :: iextent, jextent
     integer :: maxloc_xy(2)
 
     input_width = size(input, dim=2)
-    input_height = size(input, dim=2)
+    input_height = size(input, dim=3)
+
+    iextent = input_width - mod(input_width, self % stride)
+    jextent = input_height - mod(input_height, self % stride)
 
     ! Stride along the width and height of the input image
     stride_over_input: do concurrent( &
-      i = 1:input_width:self % stride, &
-      j = 1:input_height:self % stride &
+      i = 1:iextent:self % stride, &
+      j = 1:jextent:self % stride &
     )
 
       ! Indices of the pooling layer
diff --git a/src/nf/nf_network.f90 b/src/nf/nf_network.f90
index 0bac70c1..d9c90821 100644
--- a/src/nf/nf_network.f90
+++ b/src/nf/nf_network.f90
@@ -17,15 +17,17 @@ module nf_network
 
   contains
 
     procedure :: backward
-    procedure :: output
     procedure :: print_info
     procedure :: train
     procedure :: update
 
     procedure, private :: forward_1d
     procedure, private :: forward_3d
+    procedure, private :: output_1d
+    procedure, private :: output_3d
 
     generic :: forward => forward_1d, forward_3d
+    generic :: output => output_1d, output_3d
 
   end type network
 
@@ -72,6 +74,30 @@ end subroutine forward_3d
 
   end interface forward
 
+  interface output
+
+    module function output_1d(self, input) result(res)
+      !! Return the output of the network given the input 1-d array.
+      class(network), intent(in out) :: self
+        !! Network instance
+      real, intent(in) :: input(:)
+        !! Input data
+      real, allocatable :: res(:)
+        !! Output of the network
+    end function output_1d
+
+    module function output_3d(self, input) result(res)
+      !! Return the output of the network given the input 3-d array.
+      class(network), intent(in out) :: self
+        !! Network instance
+      real, intent(in) :: input(:,:,:)
+        !! Input data
+      real, allocatable :: res(:)
+        !! Output of the network
+    end function output_3d
+
+  end interface output
+
   interface
 
     pure module subroutine backward(self, output)
@@ -85,16 +111,6 @@ pure module subroutine backward(self, output)
         !! Output data
     end subroutine backward
 
-    module function output(self, input) result(res)
-      !! Return the output of the network given the input array.
-      class(network), intent(in out) :: self
-        !! Network instance
-      real, intent(in) :: input(:)
-        !! Input data
-      real, allocatable :: res(:)
-        !! Output of the network
-    end function output
-
     module subroutine print_info(self)
       !! Prints a brief summary of the network and its layers to the screen.
      class(network), intent(in) :: self
 
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index f2a6c909..7d49bec8 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -1,6 +1,7 @@
 submodule(nf_network) nf_network_submodule
 
   use nf_dense_layer, only: dense_layer
+  use nf_flatten_layer, only: flatten_layer
   use nf_input1d_layer, only: input1d_layer
   use nf_input3d_layer, only: input3d_layer
   use nf_layer, only: layer
@@ -30,11 +31,10 @@ module function network_cons(layers) result(res)
     !TODO Ensure that the layers are in allowed sequence:
     !TODO   input1d -> dense
     !TODO   dense -> dense
-    !TODO   input3d -> conv2d
-    !TODO   conv2d -> conv2d
-    !TODO   conv2d -> maxpool2d
-    !TODO   maxpool2d -> conv2d
-    !TODO   conv2d -> flatten
+    !TODO   input3d -> conv2d, maxpool2d, flatten
+    !TODO   conv2d -> conv2d, maxpool2d, flatten
+    !TODO   maxpool2d -> conv2d, maxpool2d, flatten
+    !TODO   flatten -> dense
 
     res % layers = layers
 
@@ -115,7 +115,7 @@ pure module subroutine forward_3d(self, input)
   end subroutine forward_3d
 
-  module function output(self, input) result(res)
+  module function output_1d(self, input) result(res)
     class(network), intent(in out) :: self
     real, intent(in) :: input(:)
     real, allocatable :: res(:)
@@ -125,11 +125,34 @@ module function output(self, input) result(res)
 
     call self % forward(input)
 
-    select type(output_layer => self % layers(num_layers) % p); type is(dense_layer)
-      res = output_layer % output
+    select type(output_layer => self % layers(num_layers) % p)
+      type is(dense_layer)
+        res = output_layer % output
+      type is(flatten_layer)
+        res = output_layer % output
     end select
 
-  end function output
+  end function output_1d
+
+
+  module function output_3d(self, input) result(res)
+    class(network), intent(in out) :: self
+    real, intent(in) :: input(:,:,:)
+    real, allocatable :: res(:)
+    integer :: num_layers
+
+    num_layers = size(self % layers)
+
+    call self % forward(input)
+
+    select type(output_layer => self % layers(num_layers) % p)
+      type is(dense_layer)
+        res = output_layer % output
+      type is(flatten_layer)
+        res = output_layer % output
+    end select
+
+  end function output_3d
 
   module subroutine print_info(self)
diff --git a/test/test_flatten_layer.f90 b/test/test_flatten_layer.f90
new file mode 100644
index 00000000..cc780acd
--- /dev/null
+++ b/test/test_flatten_layer.f90
@@ -0,0 +1,89 @@
+program test_flatten_layer
+
+  use iso_fortran_env, only: stderr => error_unit
+  use nf, only: dense, flatten, input, layer, network
+  use nf_flatten_layer, only: flatten_layer
+  use nf_input3d_layer, only: input3d_layer
+
+  implicit none
+
+  type(layer) :: test_layer, input_layer
+  type(network) :: net
+  real, allocatable :: gradient(:,:,:)
+  real, allocatable :: output(:)
+  logical :: ok = .true.
+
+  test_layer = flatten()
+
+  if (.not. test_layer % name == 'flatten') then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer has its name set correctly.. failed'
+  end if
+
+  if (test_layer % initialized) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer is not initialized yet.. failed'
+  end if
+
+  input_layer = input([1, 2, 2])
+  call test_layer % init(input_layer)
+
+  if (.not. test_layer % initialized) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer is now initialized.. failed'
+  end if
+
+  if (.not. all(test_layer % layer_shape == [4])) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer has an incorrect output shape.. failed'
+  end if
+
+  ! Test forward pass - reshaping from 3-d to 1-d
+
+  select type(this_layer => input_layer % p); type is(input3d_layer)
+    call this_layer % set(reshape(real([1, 2, 3, 4]), [1, 2, 2]))
+  end select
+
+  call test_layer % forward(input_layer)
+  call test_layer % get_output(output)
+
+  if (.not. all(output == [1, 2, 3, 4])) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer correctly propagates forward.. failed'
+  end if
+
+  ! Test backward pass - reshaping from 1-d to 3-d
+
+  ! Calling backward() will set the values on the gradient component
+  ! input_layer is used only to determine shape
+  call test_layer % backward(input_layer, real([1, 2, 3, 4]))
+
+  select type(this_layer => test_layer % p); type is(flatten_layer)
+    gradient = this_layer % gradient
+  end select
+
+  if (.not. all(gradient == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
+  end if
+
+  net = network([ &
+    input([1, 28, 28]), &
+    flatten(), &
+    dense(10) &
+  ])
+
+  ! Test that the output layer receives 784 elements in the input
+  if (.not. all(net % layers(3) % input_layer_shape == [784])) then
+    ok = .false.
+    write(stderr, '(a)') 'flatten layer correctly chains input3d to dense.. failed'
+  end if
+
+  if (ok) then
+    print '(a)', 'test_flatten_layer: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_flatten_layer: One or more tests failed.'
+    stop 1
+  end if
+
+end program test_flatten_layer
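
For reviewers: a minimal sketch, not part of the patch, of how the new pieces compose end to end. It uses only names introduced or re-exported by this change (`input`, `flatten`, `dense`, `network`, and the new `output` generic); the program name and shapes are arbitrary.

```
program flatten_demo
  use nf, only: dense, flatten, input, network
  implicit none
  type(network) :: net
  real :: x(1, 2, 2)

  ! A flatten layer bridges 3-d outputs (input3d, conv2d, maxpool2d)
  ! to 1-d inputs (dense); it may not be the first layer in a network.
  net = network([ &
    input([1, 2, 2]), & ! 3-d input with 1 * 2 * 2 = 4 values
    flatten(), &        ! output shape: [4]
    dense(2) &          ! dense head operating on the flattened vector
  ])

  call random_number(x)

  ! Passing a 3-d array resolves to the new output_3d specific;
  ! the result is the 1-d output of the final dense layer.
  print *, net % output(x)

end program flatten_demo
```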