Skip to content

Commit

Permalink
Merge pull request #38 from BerkeleyLab/read-metadata
Browse files Browse the repository at this point in the history
Feature: read metadata if present
  • Loading branch information
rouson authored Feb 14, 2023
2 parents 537969d + 204f0bb commit 9a1095e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/inference_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ module inference_engine_m
public :: inference_engine_t
public :: inputs_t
public :: outputs_t
public :: infer_from_inputs_object

type inputs_t
real(rkind), allocatable :: inputs_(:)
Expand All @@ -23,9 +22,15 @@ module inference_engine_m
real(rkind), allocatable :: outputs_(:)
end type

type metadata_t
character(len=:), allocatable :: modelName, modelAuthor, compilationDate
logical usingSkipConnections
end type

type inference_engine_t
!! Encapsulate the minimal information needed to performance inference
private
type(metadata_t) metadata_
real(rkind), allocatable :: input_weights_(:,:) ! weights applied to go from the inputs to first hidden layer
real(rkind), allocatable :: hidden_weights_(:,:,:) ! weights applied to go from one hidden layer to the next
real(rkind), allocatable :: output_weights_(:,:) ! weights applied to go from the final hidden layer to the outputs
Expand Down
73 changes: 68 additions & 5 deletions src/inference_engine_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,45 @@
type(layer_t) hidden_layers
type(neuron_t) output_neuron
real(rkind), allocatable :: hidden_weights(:,:,:)
integer l
character(len=:), allocatable :: quoted_value, line

lines = file_%lines()

call assert(adjustl(lines(1)%string())=="{", "from_json: expecting '{' for to start outermost object", lines(1)%string())
call assert(adjustl(lines(2)%string())=='"hidden_layers": [', 'from_json: expecting "hidden_layers": [', lines(2)%string())
l = 1
call assert(adjustl(lines(l)%string())=="{", "construct_from_json: expecting '{' to start outermost object", lines(l)%string())
l = 2
if (adjustl(lines(l)%string()) /= '"metadata": {') then
inference_engine%metadata_ = metadata_t(modelName="",modelAuthor="",compilationDate="", usingSkipConnections=.false.)
else
l = l + 1
inference_engine%metadata_%modelName = get_string_value(adjustl(lines(l)%string()), key="modelName")

l = l + 1
inference_engine%metadata_%modelAuthor = get_string_value(adjustl(lines(l)%string()), key="modelAuthor")

l = l + 1
inference_engine%metadata_%compilationDate = get_string_value(adjustl(lines(l)%string()), key="compilationDate")

l = l + 1
inference_engine%metadata_%usingSkipConnections = get_logical_value(adjustl(lines(l)%string()), key="usingSkipConnections")

l = l + 1
call assert(adjustl(lines(l)%string())=="},", "construct_from_json: expecting '},' to end metadata object", lines(l)%string())

l = l + 1
end if

call assert(adjustl(lines(l)%string())=='"hidden_layers": [', 'from_json: expecting "hidden_layers": [', lines(l)%string())
l = l + 1

block
integer, parameter :: first_layer_line=3, lines_per_neuron=4, bracket_lines_per_layer=2
integer, parameter :: lines_per_neuron=4, bracket_lines_per_layer=2
character(len=:), allocatable :: output_layer_line

hidden_layers = layer_t(lines, start=first_layer_line)
hidden_layers = layer_t(lines, start=l)

associate( output_layer_line_number => first_layer_line + lines_per_neuron*sum(hidden_layers%count_neurons()) &
associate( output_layer_line_number => l + lines_per_neuron*sum(hidden_layers%count_neurons()) &
+ bracket_lines_per_layer*hidden_layers%count_layers() + 1)

output_layer_line = lines(output_layer_line_number)%string()
Expand Down Expand Up @@ -113,6 +139,43 @@

call assert_consistent(inference_engine)

contains

pure function get_string_value(line, key) result(value_)
character(len=*), intent(in) :: line, key
character(len=:), allocatable :: value_

associate(opening_key_quotes => index(line, '"'), separator => index(line, ':'))
associate(closing_key_quotes => opening_key_quotes + index(line(opening_key_quotes+1:), '"'))
associate(unquoted_key => line(opening_key_quotes+1:closing_key_quotes-1), remainder => line(separator+1:))
call assert(unquoted_key == key,"construct_from_json(get_string_value): unquoted_key == key ", unquoted_key)
associate(opening_value_quotes => index(remainder, '"'))
associate(closing_value_quotes => opening_value_quotes + index(remainder(opening_value_quotes+1:), '"'))
value_ = remainder(opening_value_quotes+1:closing_value_quotes-1)
end associate
end associate
end associate
end associate
end associate
end function

pure function get_logical_value(line, key) result(value_)
character(len=*), intent(in) :: line, key
logical value_
character(len=:), allocatable :: remainder ! a gfortran bug prevents making this an association

associate(opening_key_quotes => index(line, '"'), separator => index(line, ':'))
associate(closing_key_quotes => opening_key_quotes + index(line(opening_key_quotes+1:), '"'))
associate(unquoted_key => line(opening_key_quotes+1:closing_key_quotes-1))
call assert(unquoted_key == key,"construct_from_json(get_string_value): unquoted_key == key ", unquoted_key)
remainder = adjustl(line(separator+1:))
call assert(any(remainder == ["true ", "false"]), "construct_from_json(get_logical_value): valid value", remainder)
value_ = remainder == "true"
end associate
end associate
end associate
end function

end procedure construct_from_json


Expand Down

0 comments on commit 9a1095e

Please sign in to comment.