Skip to content

Commit

Permalink
Merge pull request #54 from BerkeleyLab/mini-batch
Browse files Browse the repository at this point in the history
Group input/output pairs into mini-batches
  • Loading branch information
rouson authored May 24, 2023
2 parents c780a8e + b7bb986 commit a2b7027
Show file tree
Hide file tree
Showing 13 changed files with 499 additions and 79 deletions.
116 changes: 116 additions & 0 deletions example/train-xor-gate.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt
program train_xor_gate
  !! Train a neural network to learn the exclusive-or (XOR) truth table and
  !! verify that the trained network reproduces all four table entries.
  !! Usage: ./build/run-fpm.sh run --example train-xor-gate -- --base-name "<base-file-name>"
  use string_m, only : string_t
  use trainable_engine_m, only : trainable_engine_t
  use inputs_m, only : inputs_t
  use outputs_m, only : outputs_t
  use expected_outputs_m, only : expected_outputs_t
  use matmul_m, only : matmul_t
  use kind_parameters_m, only : rkind
  use sigmoid_m, only : sigmoid_t
  use input_output_pair_m, only : input_output_pair_t
  use mini_batch_m, only : mini_batch_t
  use file_m, only : file_t
  use command_line_m, only : command_line_t
  implicit none

  real(rkind), parameter :: tolerance = 1.E-02_rkind, false = 0._rkind, true = 1._rkind
  type(outputs_t), dimension(4) :: actual_output
  type(expected_outputs_t), dimension(4) :: expected_outputs
  type(trainable_engine_t) trainable_engine
  type(command_line_t) command_line
  type(string_t) base_name
  type(mini_batch_t), allocatable :: mini_batches(:)
  type(inputs_t), allocatable :: inputs(:)
  integer i, m

  base_name = string_t(command_line%flag_value("--base-name"))

  if (len(base_name%string())==0) then
    error stop new_line('a') // new_line('a') // &
      'Usage: ./build/run-fpm.sh run --example train-xor-gate -- --base-name "<base-file-name>"'
  end if

  ! XOR truth table (T = 1, F = 0): TT->F, FT->T, TF->T, FF->F
  inputs = [ &
    inputs_t([true,true]), inputs_t([false,true]), inputs_t([true,false]), inputs_t([false,false]) &
  ]
  expected_outputs = [ &
    expected_outputs_t([false]), expected_outputs_t([true]), expected_outputs_t([true]), expected_outputs_t([false]) &
  ]
  print *,"Defining mini-batches, each containing input/output pairs corresponding to the four entries in the XOR truth table."
  mini_batches = [(mini_batch_t( input_output_pair_t( inputs, expected_outputs ) ), m=1,2000000)]
  print *,"Defining an initial trainable_engine_t neural network object."
  trainable_engine = wide_single_layer_perceptron()
  print *,"Training the neural network using the mini-batches. ___This could take a few minutes.___"
  call trainable_engine%train(mini_batches, matmul_t())
  associate(file_name => string_t(base_name%string() // ".json"))
    ! Include the destination file name in the message (previously truncated).
    print *,"Writing the network parameters to " // file_name%string()
    call output(trainable_engine, file_name)
  end associate
  print *,"Verifying that the network behaves as an exclusive-or (XOR) logic gate: "
  actual_output = trainable_engine%infer(inputs, matmul_t())
  ! Success requires *every* truth-table entry to be within tolerance, so use
  ! all(...). The original any(...) would report success if even one of the
  ! four XOR entries happened to be close.
  if (all([(abs(actual_output(i)%outputs() - expected_outputs(i)%outputs()) < tolerance, i=1, size(actual_output))])) then
    print *,"Yes!"
  else
    error stop &
      "The trained network does not behave as a XOR gate. " // &
      "Please report this issue at https://github.com/BerkeleyLab/inference-engine/issues."
  end if

contains

  subroutine output(engine, file_name)
    !! Serialize the trained engine to a JSON file named file_name.
    type(trainable_engine_t), intent(in) :: engine
    type(string_t), intent(in) :: file_name
    type(file_t) json_file

    json_file = trainable_engine%to_json()
    print *, "Writing an inference_engine_t object to the file '"//file_name%string()//"' in JSON format."
    call json_file%write_lines(file_name)
  end subroutine

  function wide_single_layer_perceptron() result(trainable_engine)
    !! Construct a single-hidden-layer network with hand-chosen initial
    !! weights/biases close to a known XOR solution.
    type(trainable_engine_t) trainable_engine
    integer, parameter :: n_in = 2     ! number of inputs
    integer, parameter :: n_out = 1    ! number of outputs
    integer, parameter :: neurons = 3  ! number of neurons per layer
    integer, parameter :: n_hidden = 1 ! number of hidden layers

    trainable_engine = trainable_engine_t( &
      metadata = [ &
        string_t("Trainable XOR"), string_t("Damian Rouson"), string_t("2023-05-09"), string_t("sigmoid"), string_t("false") &
      ], &
      input_weights = real(reshape([1,0,1,1,0,1], [n_in, neurons]), rkind), &
      ! n_hidden-1 == 0: no hidden-to-hidden weights for a single hidden layer
      hidden_weights = reshape([real(rkind)::], [neurons,neurons,n_hidden-1]), &
      output_weights = real(reshape([1,-2,1], [n_out, neurons]), rkind), &
      biases = reshape([real(rkind):: 0.,-1.99,0.], [neurons, n_hidden]), &
      output_biases = [real(rkind):: 0.], &
      differentiable_activation_strategy = sigmoid_t() &
    )
  end function

  function wide_perceptron() result(trainable_engine)
    !! Construct a wider (24-neuron) single-hidden-layer network by tiling the
    !! 3-neuron XOR pattern eight times. Currently uncalled; retained as an
    !! alternative starting network for experimentation.
    type(trainable_engine_t) trainable_engine
    integer, parameter :: n_in = 2     ! number of inputs
    integer, parameter :: n_out = 1    ! number of outputs
    integer, parameter :: neurons = 24 ! number of neurons per layer
    integer, parameter :: n_hidden = 1 ! number of hidden layers
    integer n

    trainable_engine = trainable_engine_t( &
      metadata = [ &
        string_t("Wide 1-layer perceptron"), string_t("Damian Rouson"), string_t("2023-05-24"), string_t("sigmoid"), string_t("false") &
      ], &
      input_weights = real(reshape([([1,0,1,1,0,1], n=1,8 )], [n_in, neurons]), rkind), &
      hidden_weights = reshape([real(rkind)::], [neurons,neurons,n_hidden-1]), &
      output_weights = real(reshape([([1,-2,1], n=1,8)], [n_out, neurons]), rkind), &
      biases = reshape([real(rkind):: [(0.,-1.99,0., n=1,8)] ], [neurons, n_hidden]), &
      output_biases = [real(rkind):: 0.], &
      differentiable_activation_strategy = sigmoid_t() &
    )
  end function

end program train_xor_gate
8 changes: 4 additions & 4 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ HDF5_LIB_PATH="`brew --prefix hdf5`/lib"
NETCDFF_LIB_PATH="`brew --prefix netcdf-fortran`/lib"

FPM_LD_FLAG=" -L$NETCDF_LIB_PATH -L$HDF5_LIB_PATH -L$NETCDFF_LIB_PATH"
FPM_FLAG="-fallow-argument-mismatch -ffree-line-length-none -L$NETCDF_LIB_PATH -L$HDF5_LIB_PATH"
FPM_FC=${FC:-"gfortran-12"}
FPM_CC=${CC:-"gcc-12"}
FPM_FLAG="-O3 -fallow-argument-mismatch -ffree-line-length-none -L$NETCDF_LIB_PATH -L$HDF5_LIB_PATH"
FPM_FC=${FC:-"gfortran-13"}
FPM_CC=${CC:-"gcc-13"}

mkdir -p build

Expand Down Expand Up @@ -106,7 +106,7 @@ export PKG_CONFIG_PATH
cp scripts/run-fpm.sh-header build/run-fpm.sh
RUN_FPM_SH="`realpath ./build/run-fpm.sh`"
echo "`which fpm` \$fpm_arguments \\" >> $RUN_FPM_SH
echo "--profile debug \\" >> $RUN_FPM_SH
echo "--profile release \\" >> $RUN_FPM_SH
echo "--c-compiler \"`pkg-config inference-engine --variable=INFERENCE_ENGINE_FPM_CC`\" \\" >> $RUN_FPM_SH
echo "--compiler \"`pkg-config inference-engine --variable=INFERENCE_ENGINE_FPM_FC`\" \\" >> $RUN_FPM_SH
echo "--flag \"`pkg-config inference-engine --variable=INFERENCE_ENGINE_FPM_FLAG`\" \\" >> $RUN_FPM_SH
Expand Down
7 changes: 3 additions & 4 deletions src/inference_engine/inference_engine_m_.f90
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module inference_engine_m_
use inputs_m, only : inputs_t
use outputs_m, only : outputs_t
use differentiable_activation_strategy_m, only :differentiable_activation_strategy_t
use network_increment_m, only : network_increment_t
implicit none

private
Expand Down Expand Up @@ -171,12 +172,10 @@ pure module function output_weights(self) result(w)
real(rkind), allocatable :: w(:,:)
end function

pure module subroutine increment(self, delta_w_in, delta_w_hidden, delta_w_out, delta_b_hidden, delta_b_out)
pure module subroutine increment(self, network_increment)
implicit none
class(inference_engine_t), intent(inout) :: self
real(rkind), intent(in), dimension(:,:,:) :: delta_w_hidden
real(rkind), intent(in), dimension(:,:) :: delta_w_in, delta_w_out, delta_b_hidden
real(rkind), intent(in), dimension(:) :: delta_b_out
type(network_increment_t), intent(in) :: network_increment
end subroutine

end interface
Expand Down
10 changes: 5 additions & 5 deletions src/inference_engine/inference_engine_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,11 @@ function get_key_value(line) result(value_)
end procedure

module procedure increment
! Apply one additive update to every parameter set of the network.
! Each network_increment accessor returns an array conformable with the
! corresponding component, so whole-array addition updates all elements.
self%input_weights_ = self%input_weights_ + network_increment%delta_w_in()
self%hidden_weights_ = self%hidden_weights_ + network_increment%delta_w_hidden()
self%output_weights_ = self%output_weights_ + network_increment%delta_w_out()
self%biases_ = self%biases_ + network_increment%delta_b_hidden()
self%output_biases_ = self%output_biases_ + network_increment%delta_b_out()
end procedure

end submodule inference_engine_s
36 changes: 36 additions & 0 deletions src/inference_engine/mini_batch_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
module mini_batch_m
  !! Group input/output pairs into a mini-batch: the unit of work consumed by
  !! the training procedure. (The previously imported rkind kind parameter was
  !! unused and has been removed.)
  use input_output_pair_m, only : input_output_pair_t
  implicit none

  private
  public :: mini_batch_t

  type mini_batch_t
    !! A collection of input/output pairs processed together during training
    private
    type(input_output_pair_t), allocatable :: input_output_pairs_(:)
  contains
    procedure :: input_output_pairs
  end type

  interface mini_batch_t

    pure module function construct(input_output_pairs) result(mini_batch)
      !! Construct a mini-batch holding a copy of the given input/output pairs
      implicit none
      type(input_output_pair_t), intent(in) :: input_output_pairs(:)
      type(mini_batch_t) mini_batch
    end function

  end interface

  interface

    pure module function input_output_pairs(self) result(my_input_output_pairs)
      !! Return a copy of the stored input/output pairs
      implicit none
      class(mini_batch_t), intent(in) :: self
      type(input_output_pair_t), allocatable :: my_input_output_pairs(:)
    end function

  end interface

end module mini_batch_m
14 changes: 14 additions & 0 deletions src/inference_engine/mini_batch_s.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
submodule(mini_batch_m) mini_batch_s
  !! Implement the procedures whose interfaces mini_batch_m declares.
  implicit none

contains

  module procedure construct
    ! Allocation-on-assignment copies the pairs into the private component
    mini_batch%input_output_pairs_ = input_output_pairs
  end procedure

  module procedure input_output_pairs
    ! Return a copy (not a reference) of the stored pairs
    my_input_output_pairs = self%input_output_pairs_
  end procedure

end submodule mini_batch_s
100 changes: 100 additions & 0 deletions src/inference_engine/network_increment_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
module network_increment_m
  !! Aggregate per-parameter-set deltas for a neural network so that the
  !! increments from multiple training steps can be summed (+), scaled by an
  !! integer divisor (/), averaged (.average.), and applied as a single unit.
  use kind_parameters_m, only : rkind
  implicit none

  private
  public :: network_increment_t
  public :: operator(.average.)

  type network_increment_t
    !! One additive update for each set of network parameters
    private
    real(rkind), allocatable :: delta_w_in_(:,:)       ! input-layer weight deltas
    real(rkind), allocatable :: delta_w_hidden_(:,:,:) ! hidden-layer weight deltas
    real(rkind), allocatable :: delta_w_out_(:,:)      ! output-layer weight deltas
    real(rkind), allocatable :: delta_b_hidden_(:,:)   ! hidden-layer bias deltas
    real(rkind), allocatable :: delta_b_out_(:)        ! output-layer bias deltas
  contains
    procedure, private :: add
    generic :: operator(+) => add
    procedure, private :: divide
    generic :: operator(/) => divide
    procedure :: delta_w_in
    procedure :: delta_w_hidden
    procedure :: delta_w_out
    procedure :: delta_b_hidden
    procedure :: delta_b_out
  end type

  interface network_increment_t

    pure module function construct(delta_w_in, delta_w_hidden, delta_w_out, delta_b_hidden, delta_b_out) result(network_increment)
      !! Construct an increment from the individual parameter-set deltas
      implicit none
      real(rkind), intent(in) :: delta_w_in(:,:)
      real(rkind), intent(in) :: delta_w_hidden(:,:,:)
      real(rkind), intent(in) :: delta_w_out(:,:)
      real(rkind), intent(in) :: delta_b_hidden(:,:)
      real(rkind), intent(in) :: delta_b_out(:)
      type(network_increment_t) network_increment
    end function

  end interface

  interface operator(.average.)

    pure module function average(rhs) result(average_increment)
      !! Return the element-wise average of an array of increments
      implicit none
      type(network_increment_t), intent(in) :: rhs(:)
      type(network_increment_t) average_increment
    end function

  end interface

  interface

    pure module function add(lhs, rhs) result(total)
      !! Element-wise sum of two increments (bound to operator(+))
      implicit none
      class(network_increment_t), intent(in) :: lhs
      type(network_increment_t), intent(in) :: rhs
      type(network_increment_t) total
    end function

    pure module function divide(numerator, denominator) result(ratio)
      !! Element-wise division of an increment by an integer (bound to operator(/))
      implicit none
      class(network_increment_t), intent(in) :: numerator
      integer, intent(in) :: denominator
      type(network_increment_t) ratio
    end function

    ! Accessors: each returns a copy of the corresponding private component.

    pure module function delta_w_in(self) result(my_delta_w_in)
      implicit none
      class(network_increment_t), intent(in) :: self
      real(rkind), allocatable :: my_delta_w_in(:,:)
    end function

    pure module function delta_w_hidden(self) result(my_delta_w_hidden)
      implicit none
      class(network_increment_t), intent(in) :: self
      real(rkind), allocatable :: my_delta_w_hidden(:,:,:)
    end function

    pure module function delta_w_out(self) result(my_delta_w_out)
      implicit none
      class(network_increment_t), intent(in) :: self
      real(rkind), allocatable :: my_delta_w_out(:,:)
    end function

    pure module function delta_b_hidden(self) result(my_delta_b_hidden)
      implicit none
      class(network_increment_t), intent(in) :: self
      real(rkind), allocatable :: my_delta_b_hidden(:,:)
    end function

    pure module function delta_b_out(self) result(my_delta_b_out)
      implicit none
      class(network_increment_t), intent(in) :: self
      real(rkind), allocatable :: my_delta_b_out(:)
    end function

  end interface

end module network_increment_m
Loading

0 comments on commit a2b7027

Please sign in to comment.