Skip to content

Commit

Permalink
Merge pull request #26 from BerkeleyLab/custom-json-reader
Browse files Browse the repository at this point in the history
Add custom JSON reader
  • Loading branch information
rouson authored Jan 23, 2023
2 parents 0546a37 + d783bd1 commit 790cff0
Show file tree
Hide file tree
Showing 10 changed files with 879 additions and 10 deletions.
70 changes: 70 additions & 0 deletions example/write-read-infer.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt
program write_read_infer
!! This program demonstrates how to write a neural network to a JSON file,
!! read the same network from the written file, query the network object for
!! some of its properties, print those properties, and use the network to
!! perform inference.
use command_line_m, only : command_line_t
use inference_engine_m, only : inference_engine_t
use string_m, only : string_t
use matmul_m, only : matmul_t
use step_m, only : step_t
use file_m, only : file_t
implicit none

type(string_t) file_name
type(command_line_t) command_line

file_name = string_t(command_line%flag_value("--output-file"))

if (len(file_name%string())==0) then
error stop new_line('a') // new_line('a') // &
'Usage: ./build/run-fpm.sh run --example write-read-infer -- --output-file "<file-name>"'
end if

call write_read_query_infer(file_name)

contains

subroutine write_read_query_infer(output_file_name)
type(string_t), intent(in) :: output_file_name
integer i, j
integer, parameter :: num_inputs = 2, num_outputs = 1, num_neurons = 3, num_hidden_layers = 2
integer, parameter :: identity(*,*,*) = &
reshape([((merge(1,0,i==j), i=1,num_neurons), j=1,num_neurons)], shape=[num_neurons,num_neurons,num_hidden_layers-1])
type(inference_engine_t) xor_network, inference_engine
type(file_t) json_output_file, json_input_file

print *, "Constructing an inference_engine_t neural-network object from scratch."
xor_network = inference_engine_t( &
input_weights = real(reshape([1,0,1,1,0,1], [num_inputs, num_neurons])), &
hidden_weights = real(identity), &
output_weights = real(reshape([1,-2,1], [num_outputs, num_neurons])), &
biases = reshape([0.,-1.99,0., 0.,0.,0.], [num_neurons, num_hidden_layers]), &
output_biases = [0.], &
inference_strategy = matmul_t() &
)
print *, "Converting an inference_engine_t object to a file_t object."
json_output_file = xor_network%to_json()

print *, "Writing an inference_engine_t object to the file '"//output_file_name%string()//"' in JSON format."
call json_output_file%write_lines(output_file_name)

print *, "Reading an inference_engine_t object from the same JSON file '"//output_file_name%string()//"'."
json_input_file = file_t(output_file_name)

print *, "Constructing a new inference_engine_t object from the parameters read."
inference_engine = inference_engine_t(json_input_file, step_t(), matmul_t())

print *, "Querying the new inference_engine_t object for several properties:"
print *, "num_outputs = ", inference_engine%num_outputs()
print *, "num_hidden_layers = ", inference_engine%num_hidden_layers()
print *, "neurons_per_layer = ", inference_engine%neurons_per_layer()

print *, "Performing inference:"
print *, "inference_engine%infer([0.,1.]) =",inference_engine%infer([0.,1.])
print *, "Correct answer for the XOR neural network: ", 1.
end subroutine write_read_query_infer

end program
48 changes: 48 additions & 0 deletions src/file_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module file_m
!! A representation of a file as an object
use string_m, only : string_t

private
public :: file_t

type file_t
private
type(string_t), allocatable :: lines_(:)
contains
procedure :: lines
procedure :: write_lines
end type

interface file_t

module function read_lines(file_name) result(file_object)
implicit none
type(string_t), intent(in) :: file_name
type(file_t) file_object
end function

pure module function construct(lines) result(file_object)
implicit none
type(string_t), intent(in), allocatable :: lines(:)
type(file_t) file_object
end function

end interface

interface

pure module function lines(self) result(my_lines)
implicit none
class(file_t), intent(in) :: self
type(string_t), allocatable :: my_lines(:)
end function

impure elemental module subroutine write_lines(self, file_name)
implicit none
class(file_t), intent(in) :: self
type(string_t), intent(in), optional :: file_name
end subroutine

end interface

end module file_m
107 changes: 107 additions & 0 deletions src/file_s.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
submodule(file_m) file_s
use iso_fortran_env, only : iostat_end, iostat_eor, output_unit
use assert_m, only : assert
implicit none

contains

module procedure construct
file_object%lines_ = lines
end procedure

module procedure write_lines

integer file_unit, io_status, l

call assert(allocated(self%lines_), "file_t%write_lines: allocated(self%lines_)")

if (present(file_name)) then
open(newunit=file_unit, file=file_name%string(), form='formatted', status='unknown', iostat=io_status, action='write')
call assert(io_status==0,"write_lines: io_status==0 after 'open' statement", file_name%string())
else
file_unit = output_unit
end if

do l = 1, size(self%lines_)
write(file_unit, *) self%lines_(l)%string()
end do

if (present(file_name)) close(file_unit)
end procedure

module procedure read_lines

integer io_status, file_unit, line_num
character(len=:), allocatable :: line
integer, parameter :: max_message_length=128
character(len=max_message_length) error_message
integer, allocatable :: lengths(:)

open(newunit=file_unit, file=file_name%string(), form='formatted', status='old', iostat=io_status, action='read')
call assert(io_status==0,"read_lines: io_status==0 after 'open' statement", file_name%string())

lengths = line_lengths(file_unit)

associate(num_lines => size(lengths))

allocate(file_object%lines_(num_lines))

do line_num = 1, num_lines
allocate(character(len=lengths(line_num)) :: line)
read(file_unit, '(a)', iostat=io_status, iomsg=error_message) line
call assert(io_status==0,"read_lines: io_status==0 after line read", error_message)
file_object%lines_(line_num) = string_t(line)
deallocate(line)
end do

end associate

close(file_unit)

contains

function line_count(file_unit) result(num_lines)
integer, intent(in) :: file_unit
integer num_lines

rewind(file_unit)
num_lines = 0
do
read(file_unit, *, iostat=io_status)
if (io_status==iostat_end) exit
num_lines = num_lines + 1
end do
rewind(file_unit)
end function

function line_lengths(file_unit) result(lengths)
integer, intent(in) :: file_unit
integer, allocatable :: lengths(:)
integer io_status
character(len=1) c

associate(num_lines => line_count(file_unit))

allocate(lengths(num_lines), source = 0)
rewind(file_unit)

do line_num = 1, num_lines
do
read(file_unit, '(a)', advance='no', iostat=io_status, iomsg=error_message) c
if (io_status==iostat_eor) exit
lengths(line_num) = lengths(line_num) + 1
end do
end do

rewind(file_unit)

end associate
end function

end procedure

module procedure lines
my_lines = self%lines_
end procedure

end submodule file_s
16 changes: 16 additions & 0 deletions src/inference_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module inference_engine_m
use string_m, only : string_t
use inference_strategy_m, only : inference_strategy_t
use activation_strategy_m, only : activation_strategy_t
use file_m, only : file_t
implicit none

private
Expand All @@ -22,6 +23,7 @@ module inference_engine_m
class(inference_strategy_t), allocatable :: inference_strategy_
contains
procedure :: read_network
procedure :: to_json
procedure :: write_network
procedure :: infer
procedure :: num_inputs
Expand All @@ -47,10 +49,24 @@ pure module function construct &
type(inference_engine_t) inference_engine
end function

impure elemental module function from_json(file_, activation_strategy, inference_strategy) result(inference_engine)
implicit none
type(file_t), intent(in) :: file_
class(activation_strategy_t), intent(in), optional :: activation_strategy
class(inference_strategy_t), intent(in), optional :: inference_strategy
type(inference_engine_t) inference_engine
end function

end interface

interface

impure elemental module function to_json(self) result(json_file)
implicit none
class(inference_engine_t), intent(in) :: self
type(file_t) json_file
end function

impure elemental module subroutine read_network(self, file_name, activation_strategy, inference_strategy)
implicit none
class(inference_engine_t), intent(out) :: self
Expand Down
Loading

0 comments on commit 790cff0

Please sign in to comment.