Skip to content

Commit 204f0bb

Browse files
committedFeb 14, 2023
feat(construct_from_json):read metadata if present
1 parent 537969d commit 204f0bb

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed
 

‎src/inference_engine_m.f90

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ module inference_engine_m
1313
public :: inference_engine_t
1414
public :: inputs_t
1515
public :: outputs_t
16-
public :: infer_from_inputs_object
1716

1817
type inputs_t
1918
real(rkind), allocatable :: inputs_(:)
@@ -23,9 +22,15 @@ module inference_engine_m
2322
real(rkind), allocatable :: outputs_(:)
2423
end type
2524

25+
type metadata_t
26+
character(len=:), allocatable :: modelName, modelAuthor, compilationDate
27+
logical usingSkipConnections
28+
end type
29+
2630
type inference_engine_t
2731
!! Encapsulate the minimal information needed to performance inference
2832
private
33+
type(metadata_t) metadata_
2934
real(rkind), allocatable :: input_weights_(:,:) ! weights applied to go from the inputs to first hidden layer
3035
real(rkind), allocatable :: hidden_weights_(:,:,:) ! weights applied to go from one hidden layer to the next
3136
real(rkind), allocatable :: output_weights_(:,:) ! weights applied to go from the final hidden layer to the outputs

‎src/inference_engine_s.f90

+68-5
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,45 @@
5252
type(layer_t) hidden_layers
5353
type(neuron_t) output_neuron
5454
real(rkind), allocatable :: hidden_weights(:,:,:)
55+
integer l
56+
character(len=:), allocatable :: quoted_value, line
5557

5658
lines = file_%lines()
5759

58-
call assert(adjustl(lines(1)%string())=="{", "from_json: expecting '{' for to start outermost object", lines(1)%string())
59-
call assert(adjustl(lines(2)%string())=='"hidden_layers": [', 'from_json: expecting "hidden_layers": [', lines(2)%string())
60+
l = 1
61+
call assert(adjustl(lines(l)%string())=="{", "construct_from_json: expecting '{' to start outermost object", lines(l)%string())
62+
l = 2
63+
if (adjustl(lines(l)%string()) /= '"metadata": {') then
64+
inference_engine%metadata_ = metadata_t(modelName="",modelAuthor="",compilationDate="", usingSkipConnections=.false.)
65+
else
66+
l = l + 1
67+
inference_engine%metadata_%modelName = get_string_value(adjustl(lines(l)%string()), key="modelName")
68+
69+
l = l + 1
70+
inference_engine%metadata_%modelAuthor = get_string_value(adjustl(lines(l)%string()), key="modelAuthor")
71+
72+
l = l + 1
73+
inference_engine%metadata_%compilationDate = get_string_value(adjustl(lines(l)%string()), key="compilationDate")
74+
75+
l = l + 1
76+
inference_engine%metadata_%usingSkipConnections = get_logical_value(adjustl(lines(l)%string()), key="usingSkipConnections")
77+
78+
l = l + 1
79+
call assert(adjustl(lines(l)%string())=="},", "construct_from_json: expecting '},' to end metadata object", lines(l)%string())
80+
81+
l = l + 1
82+
end if
83+
84+
call assert(adjustl(lines(l)%string())=='"hidden_layers": [', 'from_json: expecting "hidden_layers": [', lines(l)%string())
85+
l = l + 1
6086

6187
block
62-
integer, parameter :: first_layer_line=3, lines_per_neuron=4, bracket_lines_per_layer=2
88+
integer, parameter :: lines_per_neuron=4, bracket_lines_per_layer=2
6389
character(len=:), allocatable :: output_layer_line
6490

65-
hidden_layers = layer_t(lines, start=first_layer_line)
91+
hidden_layers = layer_t(lines, start=l)
6692

67-
associate( output_layer_line_number => first_layer_line + lines_per_neuron*sum(hidden_layers%count_neurons()) &
93+
associate( output_layer_line_number => l + lines_per_neuron*sum(hidden_layers%count_neurons()) &
6894
+ bracket_lines_per_layer*hidden_layers%count_layers() + 1)
6995

7096
output_layer_line = lines(output_layer_line_number)%string()
@@ -113,6 +139,43 @@
113139

114140
call assert_consistent(inference_engine)
115141

142+
contains
143+
144+
pure function get_string_value(line, key) result(value_)
145+
character(len=*), intent(in) :: line, key
146+
character(len=:), allocatable :: value_
147+
148+
associate(opening_key_quotes => index(line, '"'), separator => index(line, ':'))
149+
associate(closing_key_quotes => opening_key_quotes + index(line(opening_key_quotes+1:), '"'))
150+
associate(unquoted_key => line(opening_key_quotes+1:closing_key_quotes-1), remainder => line(separator+1:))
151+
call assert(unquoted_key == key,"construct_from_json(get_string_value): unquoted_key == key ", unquoted_key)
152+
associate(opening_value_quotes => index(remainder, '"'))
153+
associate(closing_value_quotes => opening_value_quotes + index(remainder(opening_value_quotes+1:), '"'))
154+
value_ = remainder(opening_value_quotes+1:closing_value_quotes-1)
155+
end associate
156+
end associate
157+
end associate
158+
end associate
159+
end associate
160+
end function
161+
162+
pure function get_logical_value(line, key) result(value_)
163+
character(len=*), intent(in) :: line, key
164+
logical value_
165+
character(len=:), allocatable :: remainder ! a gfortran bug prevents making this an association
166+
167+
associate(opening_key_quotes => index(line, '"'), separator => index(line, ':'))
168+
associate(closing_key_quotes => opening_key_quotes + index(line(opening_key_quotes+1:), '"'))
169+
associate(unquoted_key => line(opening_key_quotes+1:closing_key_quotes-1))
170+
call assert(unquoted_key == key,"construct_from_json(get_string_value): unquoted_key == key ", unquoted_key)
171+
remainder = adjustl(line(separator+1:))
172+
call assert(any(remainder == ["true ", "false"]), "construct_from_json(get_logical_value): valid value", remainder)
173+
value_ = remainder == "true"
174+
end associate
175+
end associate
176+
end associate
177+
end function
178+
116179
end procedure construct_from_json
117180

118181

0 commit comments

Comments
 (0)
Please sign in to comment.