-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #54 from BerkeleyLab/mini-batch
Group input/output pairs into mini-batches
- Loading branch information
Showing
13 changed files
with
499 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
! Copyright (c), The Regents of the University of California | ||
! Terms of use are as specified in LICENSE.txt | ||
program train_xor_gate | ||
!! Define inference tests and procedures required for reporting results | ||
use string_m, only : string_t | ||
use trainable_engine_m, only : trainable_engine_t | ||
use inputs_m, only : inputs_t | ||
use outputs_m, only : outputs_t | ||
use expected_outputs_m, only : expected_outputs_t | ||
use matmul_m, only : matmul_t | ||
use kind_parameters_m, only : rkind | ||
use sigmoid_m, only : sigmoid_t | ||
use input_output_pair_m, only : input_output_pair_t | ||
use mini_batch_m, only : mini_batch_t | ||
use file_m, only : file_t | ||
use command_line_m, only : command_line_t | ||
implicit none | ||
|
||
real(rkind), parameter :: tolerance = 1.E-02_rkind, false = 0._rkind, true = 1._rkind | ||
type(outputs_t), dimension(4) :: actual_output | ||
type(expected_outputs_t), dimension(4) :: expected_outputs | ||
type(trainable_engine_t) trainable_engine | ||
type(command_line_t) command_line | ||
type(string_t) base_name | ||
type(mini_batch_t), allocatable :: mini_batches(:) | ||
type(inputs_t), allocatable :: inputs(:) | ||
character(len=5), parameter :: table_entry(*) = ["TT->F", "FT->T", "TF->T", "FF->F", "xor "] | ||
integer i, m | ||
|
||
base_name = string_t(command_line%flag_value("--base-name")) | ||
|
||
if (len(base_name%string())==0) then | ||
error stop new_line('a') // new_line('a') // & | ||
'Usage: ./build/run-fpm.sh run --example train-xor-gate -- --base-name "<base-file-name>"' | ||
end if | ||
|
||
inputs = [ & | ||
inputs_t([true,true]), inputs_t([false,true]), inputs_t([true,false]), inputs_t([false,false]) & | ||
] | ||
expected_outputs = [ & | ||
expected_outputs_t([false]), expected_outputs_t([true]), expected_outputs_t([true]), expected_outputs_t([false]) & | ||
] | ||
print *,"Defining mini-batches, each containing input/output pairs corresponding to the four entries in the XOR truth table." | ||
mini_batches = [(mini_batch_t( input_output_pair_t( inputs, expected_outputs ) ), m=1,2000000)] | ||
print *,"Defining an initial trainable_engine_t neural network object." | ||
trainable_engine = wide_single_layer_perceptron() | ||
print *,"Training the neural network using the mini-batches. ___This could take a few minutes.___" | ||
call trainable_engine%train(mini_batches, matmul_t()) | ||
associate(file_name => string_t(base_name%string() // ".json")) | ||
print *,"Writing the network parameters to " | ||
call output(trainable_engine, file_name) | ||
end associate | ||
print *,"Verifying that the network behaves as an exclusive-or (XOR) logic gate: " | ||
actual_output = trainable_engine%infer(inputs, matmul_t()) | ||
if (any([(abs(actual_output(i)%outputs() - expected_outputs(i)%outputs()) < tolerance, i=1, size(actual_output))])) then | ||
print *,"Yes!" | ||
else | ||
error stop & | ||
"The trained network does not behave as a XOR gate. " // & | ||
"Please report this issue at https://github.com/BerkeleyLab/inference-engine/issues." | ||
end if | ||
|
||
contains | ||
|
||
subroutine output(engine, file_name) | ||
type(trainable_engine_t), intent(in) :: engine | ||
type(string_t), intent(in) :: file_name | ||
type(file_t) json_file | ||
|
||
json_file = trainable_engine%to_json() | ||
print *, "Writing an inference_engine_t object to the file '"//file_name%string()//"' in JSON format." | ||
call json_file%write_lines(file_name) | ||
end subroutine | ||
|
||
function wide_single_layer_perceptron() result(trainable_engine) | ||
type(trainable_engine_t) trainable_engine | ||
integer, parameter :: n_in = 2 ! number of inputs | ||
integer, parameter :: n_out = 1 ! number of outputs | ||
integer, parameter :: neurons = 3 ! number of neurons per layer | ||
integer, parameter :: n_hidden = 1 ! number of hidden layers | ||
|
||
trainable_engine = trainable_engine_t( & | ||
metadata = [ & | ||
string_t("Trainable XOR"), string_t("Damian Rouson"), string_t("2023-05-09"), string_t("sigmoid"), string_t("false") & | ||
], & | ||
input_weights = real(reshape([1,0,1,1,0,1], [n_in, neurons]), rkind), & | ||
hidden_weights = reshape([real(rkind)::], [neurons,neurons,n_hidden-1]), & | ||
output_weights = real(reshape([1,-2,1], [n_out, neurons]), rkind), & | ||
biases = reshape([real(rkind):: 0.,-1.99,0.], [neurons, n_hidden]), & | ||
output_biases = [real(rkind):: 0.], & | ||
differentiable_activation_strategy = sigmoid_t() & | ||
) | ||
end function | ||
|
||
function wide_perceptron() result(trainable_engine) | ||
type(trainable_engine_t) trainable_engine | ||
integer, parameter :: n_in = 2 ! number of inputs | ||
integer, parameter :: n_out = 1 ! number of outputs | ||
integer, parameter :: neurons = 24 ! number of neurons per layer | ||
integer, parameter :: n_hidden = 1 ! number of hidden layers | ||
integer n | ||
|
||
trainable_engine = trainable_engine_t( & | ||
metadata = [ & | ||
string_t("Wide 1-layer perceptron"), string_t("Damian Rouson"), string_t("2023-05-24"), string_t("sigmoid"), string_t("false") & | ||
], & | ||
input_weights = real(reshape([([1,0,1,1,0,1], n=1,8 )], [n_in, neurons]), rkind), & | ||
hidden_weights = reshape([real(rkind)::], [neurons,neurons,n_hidden-1]), & | ||
output_weights = real(reshape([([1,-2,1], n=1,8)], [n_out, neurons]), rkind), & | ||
biases = reshape([real(rkind):: [(0.,-1.99,0., n=1,8)] ], [neurons, n_hidden]), & | ||
output_biases = [real(rkind):: 0.], & | ||
differentiable_activation_strategy = sigmoid_t() & | ||
) | ||
end function | ||
|
||
end program train_xor_gate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
module mini_batch_m | ||
use input_output_pair_m, only : input_output_pair_t | ||
use kind_parameters_m, only : rkind | ||
implicit none | ||
|
||
private | ||
public :: mini_batch_t | ||
|
||
type mini_batch_t | ||
private | ||
type(input_output_pair_t), allocatable :: input_output_pairs_(:) | ||
contains | ||
procedure :: input_output_pairs | ||
end type | ||
|
||
interface mini_batch_t | ||
|
||
pure module function construct(input_output_pairs) result(mini_batch) | ||
implicit none | ||
type(input_output_pair_t), intent(in) :: input_output_pairs(:) | ||
type(mini_batch_t) mini_batch | ||
end function | ||
|
||
end interface | ||
|
||
interface | ||
|
||
pure module function input_output_pairs(self) result(my_input_output_pairs) | ||
implicit none | ||
class(mini_batch_t), intent(in) :: self | ||
type(input_output_pair_t), allocatable :: my_input_output_pairs(:) | ||
end function | ||
|
||
end interface | ||
|
||
end module mini_batch_m |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
submodule(mini_batch_m) mini_batch_s | ||
implicit none | ||
|
||
contains | ||
|
||
module procedure construct | ||
mini_batch%input_output_pairs_ = input_output_pairs | ||
end procedure | ||
|
||
module procedure input_output_pairs | ||
my_input_output_pairs = self%input_output_pairs_ | ||
end procedure | ||
|
||
end submodule mini_batch_s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
module network_increment_m | ||
use kind_parameters_m, only : rkind | ||
implicit none | ||
|
||
private | ||
public :: network_increment_t | ||
public :: operator(.average.) | ||
|
||
type network_increment_t | ||
private | ||
real(rkind), allocatable :: delta_w_in_(:,:) | ||
real(rkind), allocatable :: delta_w_hidden_(:,:,:) | ||
real(rkind), allocatable :: delta_w_out_(:,:) | ||
real(rkind), allocatable :: delta_b_hidden_(:,:) | ||
real(rkind), allocatable :: delta_b_out_(:) | ||
contains | ||
procedure, private :: add | ||
generic :: operator(+) => add | ||
procedure, private :: divide | ||
generic :: operator(/) => divide | ||
procedure :: delta_w_in | ||
procedure :: delta_w_hidden | ||
procedure :: delta_w_out | ||
procedure :: delta_b_hidden | ||
procedure :: delta_b_out | ||
end type | ||
|
||
interface network_increment_t | ||
|
||
pure module function construct(delta_w_in, delta_w_hidden, delta_w_out, delta_b_hidden, delta_b_out) result(network_increment) | ||
implicit none | ||
real(rkind), intent(in) :: delta_w_in(:,:) | ||
real(rkind), intent(in) :: delta_w_hidden(:,:,:) | ||
real(rkind), intent(in) :: delta_w_out(:,:) | ||
real(rkind), intent(in) :: delta_b_hidden(:,:) | ||
real(rkind), intent(in) :: delta_b_out(:) | ||
type(network_increment_t) network_increment | ||
end function | ||
|
||
end interface | ||
|
||
interface operator(.average.) | ||
|
||
pure module function average(rhs) result(average_increment) | ||
implicit none | ||
type(network_increment_t), intent(in) :: rhs(:) | ||
type(network_increment_t) average_increment | ||
end function | ||
|
||
end interface | ||
|
||
interface | ||
|
||
pure module function add(lhs, rhs) result(total) | ||
implicit none | ||
class(network_increment_t), intent(in) :: lhs | ||
type(network_increment_t), intent(in) :: rhs | ||
type(network_increment_t) total | ||
end function | ||
|
||
pure module function divide(numerator, denominator) result(ratio) | ||
implicit none | ||
class(network_increment_t), intent(in) :: numerator | ||
integer, intent(in) :: denominator | ||
type(network_increment_t) ratio | ||
end function | ||
|
||
pure module function delta_w_in(self) result(my_delta_w_in) | ||
implicit none | ||
class(network_increment_t), intent(in) :: self | ||
real(rkind), allocatable :: my_delta_w_in(:,:) | ||
end function | ||
|
||
pure module function delta_w_hidden(self) result(my_delta_w_hidden) | ||
implicit none | ||
class(network_increment_t), intent(in) :: self | ||
real(rkind), allocatable :: my_delta_w_hidden(:,:,:) | ||
end function | ||
|
||
pure module function delta_w_out(self) result(my_delta_w_out) | ||
implicit none | ||
class(network_increment_t), intent(in) :: self | ||
real(rkind), allocatable :: my_delta_w_out(:,:) | ||
end function | ||
|
||
pure module function delta_b_hidden(self) result(my_delta_b_hidden) | ||
implicit none | ||
class(network_increment_t), intent(in) :: self | ||
real(rkind), allocatable :: my_delta_b_hidden(:,:) | ||
end function | ||
|
||
pure module function delta_b_out(self) result(my_delta_b_out) | ||
implicit none | ||
class(network_increment_t), intent(in) :: self | ||
real(rkind), allocatable :: my_delta_b_out(:) | ||
end function | ||
|
||
end interface | ||
|
||
end module network_increment_m |
Oops, something went wrong.