Skip to content

Commit

Permalink
Merge pull request #20 from BerkeleyLab/asymmetric-network-test
Browse files Browse the repository at this point in the history
WIP: Asymmetric-network test of matmul-based inference
  • Loading branch information
rouson authored Nov 29, 2022
2 parents fff67fb + 30639f8 commit 6addc82
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/matmul_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
integer layer
do layer = 2, num_layers
neuron(:,layer) = &
activation_strategy%activation(matmul(hidden_weights(:,:,layer-1), neuron(:,layer-1)) + biases(:,layer))
activation_strategy%activation(matmul(transpose(hidden_weights(:,:,layer-1)), neuron(:,layer-1)) + biases(:,layer))
end do
end block
output = activation_strategy%activation(matmul(output_weights(:,:), neuron(:,num_layers)) + output_biases(:))
Expand Down
28 changes: 18 additions & 10 deletions test/asymmetric_engine_test_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module asymmetric_engine_test_m
use test_m, only : test_t
use test_result_m, only : test_result_t
use inference_engine_m, only : inference_engine_t
use inference_strategy_m, only : inference_strategy_t
use matmul_m, only : matmul_t
implicit none

Expand All @@ -28,16 +29,20 @@ function results() result(test_results)
type(test_result_t), allocatable :: test_results(:)

test_results = test_result_t( &
[ character(len=len("mapping (true,true) to false using the default ('do concurrent'/dot_product) inference method")) :: &
"mapping (true,true) to false using the default ('do concurrent'/dot_product) inference method", &
"mapping (false,true) to true using the default inference method", &
"mapping (true,false) to false using the default inference method", &
"mapping (false,false) to false using the default inference method" &
], [xor_and_2nd_input_truth_table()] &
[ character(len=len("mapping (true,true) to false using the default ('do concurrent'/dot_product) inference strategy")) :: &
"mapping (true,true) to false using the default ('do concurrent'/dot_product) inference strategy", &
"mapping (true,false) to false using the default inference strategy", &
"mapping (false,true) to true using the default inference strategy", &
"mapping (false,false) to false using the default inference strategy", &
"mapping (true,true) to false using the matmul inference strategy", &
"mapping (true,false) to false using the matmul inference strategy", &
"mapping (false,true) to true using the matmul inference strategy", &
"mapping (false,false) to false using the matmul inference strategy" &
], [xor_and_2nd_input_truth_table(), xor_and_2nd_input_truth_table(matmul_t())] &
)
end function

function xor_and_2nd_input_network() result(inference_engine)
function xor_and_2nd_input_network(inference_strategy) result(inference_engine)

type(inference_engine_t) inference_engine
integer, parameter :: n_in = 2 ! number of inputs
Expand All @@ -46,6 +51,7 @@ function xor_and_2nd_input_network() result(inference_engine)
integer, parameter :: n_hidden = 2 ! number of hidden layers
integer i, j
real xor_into_neuron_2(neurons,neurons,n_hidden-1)
class(inference_strategy_t), intent(in), optional :: inference_strategy
xor_into_neuron_2 = 0.
xor_into_neuron_2(1:3,2,1) = [1., -2., 1.]
xor_into_neuron_2(4,4,1) = 1.
Expand All @@ -55,16 +61,18 @@ function xor_and_2nd_input_network() result(inference_engine)
hidden_weights = xor_into_neuron_2, &
output_weights = real(reshape([0,1,0,1], [n_out, neurons])), &
biases = reshape([0.,-1.99,0.,0., 0.,0.,0.,0.], [neurons, n_hidden]), &
output_biases = [-1.] &
output_biases = [-1.], &
inference_strategy = inference_strategy &
)
end function

function xor_and_2nd_input_truth_table() result(test_passes)
function xor_and_2nd_input_truth_table(inference_strategy) result(test_passes)
logical, allocatable :: test_passes(:)

type(inference_engine_t) inference_engine
class(inference_strategy_t), intent(in), optional :: inference_strategy

inference_engine = xor_and_2nd_input_network()
inference_engine = xor_and_2nd_input_network(inference_strategy)

block
real, parameter :: tolerance = 1.E-08, false = 0., true = 1.
Expand Down
78 changes: 17 additions & 61 deletions test/inference_engine_test_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module inference_engine_test_m
use test_m, only : test_t
use test_result_m, only : test_result_t
use inference_engine_m, only : inference_engine_t
use inference_strategy_m, only : inference_strategy_t
use matmul_m, only : matmul_t
implicit none

Expand All @@ -29,41 +30,21 @@ function results() result(test_results)
type(test_result_t), allocatable :: test_results(:)

test_results = test_result_t( &
[ character(len=len("mapping (true,true) to false using the default ('do concurrent'/dot_product) inference method")) :: &
"mapping (true,true) to false using the default ('do concurrent'/dot_product) inference method", &
"mapping (false,true) to true using the default inference method", &
"mapping (true,false) to true using the default inference method", &
"mapping (false,false) to false using the default inference method", &
[ character(len=len("mapping (true,true) to false using the default ('do concurrent'/dot_product) inference strategy")) :: &
"mapping (true,true) to false using the default ('do concurrent'/dot_product) inference strategy", &
"mapping (true,false) to true using the default inference strategy", &
"mapping (false,true) to true using the default inference strategy", &
"mapping (false,false) to false using the default inference strategy", &
"writing and then reading itself to and from a file", &
"mapping (true,true) to false using `matmul`-based inference method", &
"mapping (false,true) to true using `matmul`-based inference method", &
"mapping (true,false) to true using `matmul`-based inference method", &
"mapping (false,false) to false using `matmul`-based inference method" &
], [xor_truth_table(), write_then_read(), matmul_inference()] &
"mapping (true,true) to false using `matmul`-based inference strategy", &
"mapping (true,false) to true using `matmul`-based inference strategy", &
"mapping (false,true) to true using `matmul`-based inference strategy", &
"mapping (false,false) to false using `matmul`-based inference strategy" &
], [xor_truth_table(), write_then_read(), xor_truth_table(matmul_t())] &
)
end function

function xor_network() result(inference_engine)

type(inference_engine_t) inference_engine
integer, parameter :: n_in = 2 ! number of inputs
integer, parameter :: n_out = 1 ! number of outputs
integer, parameter :: neurons = 3 ! number of neurons per layer
integer, parameter :: n_hidden = 2 ! number of hidden layers
integer i, j
integer, parameter :: identity(*,*,*) = &
reshape([((merge(1,0,i==j), i=1,neurons), j=1,neurons)], shape=[neurons,neurons,n_hidden-1])

inference_engine = inference_engine_t( &
input_weights = real(reshape([1,0,1,1,0,1], [n_in, neurons])), &
hidden_weights = real(identity), &
output_weights = real(reshape([1,-2,1], [n_out, neurons])), &
biases = reshape([0.,-1.99,0., 0.,0.,0.], [neurons, n_hidden]), &
output_biases = [0.] &
)
end function

function xor_matmul_network() result(inference_engine)
function xor_network(inference_strategy) result(inference_engine)

type(inference_engine_t) inference_engine
integer, parameter :: n_in = 2 ! number of inputs
Expand All @@ -73,14 +54,15 @@ function xor_matmul_network() result(inference_engine)
integer i, j
integer, parameter :: identity(*,*,*) = &
reshape([((merge(1,0,i==j), i=1,neurons), j=1,neurons)], shape=[neurons,neurons,n_hidden-1])
class(inference_strategy_t), intent(in), optional :: inference_strategy

inference_engine = inference_engine_t( &
input_weights = real(reshape([1,0,1,1,0,1], [n_in, neurons])), &
hidden_weights = real(identity), &
output_weights = real(reshape([1,-2,1], [n_out, neurons])), &
biases = reshape([0.,-1.99,0., 0.,0.,0.], [neurons, n_hidden]), &
output_biases = [0.], &
inference_strategy = matmul_t() &
inference_strategy = inference_strategy &
)
end function

Expand All @@ -105,39 +87,13 @@ function write_then_read() result(test_passes)
end block
end function

function xor_truth_table() result(test_passes)
function xor_truth_table(inference_strategy) result(test_passes)
logical, allocatable :: test_passes(:)
class(inference_strategy_t), intent(in), optional :: inference_strategy

type(inference_engine_t) inference_engine

inference_engine = xor_network()

block
real, parameter :: tolerance = 1.E-08, false = 0., true = 1.

associate( &
true_true => inference_engine%infer(input=[true,true]), &
true_false => inference_engine%infer(input=[true,false]), &
false_true => inference_engine%infer(input=[false,true]), &
false_false => inference_engine%infer(input=[false,false]) &
)
test_passes = [ &
size(true_true)==1 .and. abs(true_true(1) - false) < tolerance, &
size(true_false)==1 .and. abs(true_false(1) - true) < tolerance, &
size(false_true)==1 .and. abs(false_true(1) - true) < tolerance, &
size(false_false)==1 .and. abs(false_false(1) - false) < tolerance &
]
end associate
end block

end function

function matmul_inference() result(test_passes)
logical, allocatable :: test_passes(:)

type(inference_engine_t) inference_engine

inference_engine = xor_matmul_network()
inference_engine = xor_network(inference_strategy)

block
real, parameter :: tolerance = 1.E-08, false = 0., true = 1.
Expand Down

0 comments on commit 6addc82

Please sign in to comment.