Skip to content

Commit

Permalink
Merge pull request #25 from BerkeleyLab/inference-option
Browse files Browse the repository at this point in the history
Make matmul the default Inference strategy
  • Loading branch information
rouson authored Jan 13, 2023
2 parents afe6cf1 + 38919c1 commit 0546a37
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
3 changes: 2 additions & 1 deletion example/read-and-infer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ program read_and_infer
use inference_engine_m, only : inference_engine_t
use string_m, only : string_t
use sigmoid_m, only : sigmoid_t
use matmul_m, only : matmul_t
implicit none

type(inference_engine_t) inference_engine
Expand All @@ -25,7 +26,7 @@ program read_and_infer
end if

print *,"Defining an inference_engine_t object by reading the file '"//input_file_name//"'"
call inference_engine%read_network(string_t(input_file_name), sigmoid_t())
call inference_engine%read_network(string_t(input_file_name), sigmoid_t(), matmul_t())

print *,"num_outputs = ", inference_engine%num_outputs()
print *,"num_hidden_layers = ", inference_engine%num_hidden_layers()
Expand Down
3 changes: 2 additions & 1 deletion src/inference_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ pure module function construct &

interface

impure elemental module subroutine read_network(self, file_name, activation_strategy)
impure elemental module subroutine read_network(self, file_name, activation_strategy, inference_strategy)
implicit none
class(inference_engine_t), intent(out) :: self
type(string_t), intent(in) :: file_name
class(activation_strategy_t), intent(in), optional :: activation_strategy
class(inference_strategy_t), intent(in), optional :: inference_strategy
end subroutine

impure elemental module subroutine write_network(self, file_name)
Expand Down
10 changes: 7 additions & 3 deletions src/inference_engine_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
submodule(inference_engine_m) inference_engine_s
use assert_m, only : assert
use intrinsic_array_m, only : intrinsic_array_t
use concurrent_dot_products_m, only : concurrent_dot_products_t
use matmul_m, only : matmul_t
use step_m, only : step_t
implicit none

Expand All @@ -23,7 +23,7 @@
if (present(inference_strategy)) then
inference_engine%inference_strategy_ = inference_strategy
else
inference_engine%inference_strategy_ = concurrent_dot_products_t()
inference_engine%inference_strategy_ = matmul_t()
end if
end procedure

Expand Down Expand Up @@ -211,7 +211,11 @@ pure subroutine assert_consistent(self)
self%activation_strategy_ = step_t()
end if

self%inference_strategy_ = concurrent_dot_products_t()
if (present(inference_strategy)) then
self%inference_strategy_ = inference_strategy
else
self%inference_strategy_ = matmul_t()
end if

close(file_unit)

Expand Down

0 comments on commit 0546a37

Please sign in to comment.