-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #26 from BerkeleyLab/custom-json-reader
Add custom JSON reader
- Loading branch information
Showing
10 changed files
with
879 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
! Copyright (c), The Regents of the University of California | ||
! Terms of use are as specified in LICENSE.txt | ||
program write_read_infer | ||
!! This program demonstrates how to write a neural network to a JSON file, | ||
!! read the same network from the written file, query the network object for | ||
!! some of its properties, print those properties, and use the network to | ||
!! perform inference. | ||
use command_line_m, only : command_line_t | ||
use inference_engine_m, only : inference_engine_t | ||
use string_m, only : string_t | ||
use matmul_m, only : matmul_t | ||
use step_m, only : step_t | ||
use file_m, only : file_t | ||
implicit none | ||
|
||
type(string_t) file_name | ||
type(command_line_t) command_line | ||
|
||
file_name = string_t(command_line%flag_value("--output-file")) | ||
|
||
if (len(file_name%string())==0) then | ||
error stop new_line('a') // new_line('a') // & | ||
'Usage: ./build/run-fpm.sh run --example write-read-infer -- --output-file "<file-name>"' | ||
end if | ||
|
||
call write_read_query_infer(file_name) | ||
|
||
contains | ||
|
||
subroutine write_read_query_infer(output_file_name) | ||
type(string_t), intent(in) :: output_file_name | ||
integer i, j | ||
integer, parameter :: num_inputs = 2, num_outputs = 1, num_neurons = 3, num_hidden_layers = 2 | ||
integer, parameter :: identity(*,*,*) = & | ||
reshape([((merge(1,0,i==j), i=1,num_neurons), j=1,num_neurons)], shape=[num_neurons,num_neurons,num_hidden_layers-1]) | ||
type(inference_engine_t) xor_network, inference_engine | ||
type(file_t) json_output_file, json_input_file | ||
|
||
print *, "Constructing an inference_engine_t neural-network object from scratch." | ||
xor_network = inference_engine_t( & | ||
input_weights = real(reshape([1,0,1,1,0,1], [num_inputs, num_neurons])), & | ||
hidden_weights = real(identity), & | ||
output_weights = real(reshape([1,-2,1], [num_outputs, num_neurons])), & | ||
biases = reshape([0.,-1.99,0., 0.,0.,0.], [num_neurons, num_hidden_layers]), & | ||
output_biases = [0.], & | ||
inference_strategy = matmul_t() & | ||
) | ||
print *, "Converting an inference_engine_t object to a file_t object." | ||
json_output_file = xor_network%to_json() | ||
|
||
print *, "Writing an inference_engine_t object to the file '"//output_file_name%string()//"' in JSON format." | ||
call json_output_file%write_lines(output_file_name) | ||
|
||
print *, "Reading an inference_engine_t object from the same JSON file '"//output_file_name%string()//"'." | ||
json_input_file = file_t(output_file_name) | ||
|
||
print *, "Constructing a new inference_engine_t object from the parameters read." | ||
inference_engine = inference_engine_t(json_input_file, step_t(), matmul_t()) | ||
|
||
print *, "Querying the new inference_engine_t object for several properties:" | ||
print *, "num_outputs = ", inference_engine%num_outputs() | ||
print *, "num_hidden_layers = ", inference_engine%num_hidden_layers() | ||
print *, "neurons_per_layer = ", inference_engine%neurons_per_layer() | ||
|
||
print *, "Performing inference:" | ||
print *, "inference_engine%infer([0.,1.]) =",inference_engine%infer([0.,1.]) | ||
print *, "Correct answer for the XOR neural network: ", 1. | ||
end subroutine write_read_query_infer | ||
|
||
end program |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
module file_m | ||
!! A representation of a file as an object | ||
use string_m, only : string_t | ||
|
||
private | ||
public :: file_t | ||
|
||
type file_t | ||
private | ||
type(string_t), allocatable :: lines_(:) | ||
contains | ||
procedure :: lines | ||
procedure :: write_lines | ||
end type | ||
|
||
interface file_t | ||
|
||
module function read_lines(file_name) result(file_object) | ||
implicit none | ||
type(string_t), intent(in) :: file_name | ||
type(file_t) file_object | ||
end function | ||
|
||
pure module function construct(lines) result(file_object) | ||
implicit none | ||
type(string_t), intent(in), allocatable :: lines(:) | ||
type(file_t) file_object | ||
end function | ||
|
||
end interface | ||
|
||
interface | ||
|
||
pure module function lines(self) result(my_lines) | ||
implicit none | ||
class(file_t), intent(in) :: self | ||
type(string_t), allocatable :: my_lines(:) | ||
end function | ||
|
||
impure elemental module subroutine write_lines(self, file_name) | ||
implicit none | ||
class(file_t), intent(in) :: self | ||
type(string_t), intent(in), optional :: file_name | ||
end subroutine | ||
|
||
end interface | ||
|
||
end module file_m |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
submodule(file_m) file_s | ||
use iso_fortran_env, only : iostat_end, iostat_eor, output_unit | ||
use assert_m, only : assert | ||
implicit none | ||
|
||
contains | ||
|
||
module procedure construct | ||
file_object%lines_ = lines | ||
end procedure | ||
|
||
module procedure write_lines | ||
|
||
integer file_unit, io_status, l | ||
|
||
call assert(allocated(self%lines_), "file_t%write_lines: allocated(self%lines_)") | ||
|
||
if (present(file_name)) then | ||
open(newunit=file_unit, file=file_name%string(), form='formatted', status='unknown', iostat=io_status, action='write') | ||
call assert(io_status==0,"write_lines: io_status==0 after 'open' statement", file_name%string()) | ||
else | ||
file_unit = output_unit | ||
end if | ||
|
||
do l = 1, size(self%lines_) | ||
write(file_unit, *) self%lines_(l)%string() | ||
end do | ||
|
||
if (present(file_name)) close(file_unit) | ||
end procedure | ||
|
||
module procedure read_lines | ||
|
||
integer io_status, file_unit, line_num | ||
character(len=:), allocatable :: line | ||
integer, parameter :: max_message_length=128 | ||
character(len=max_message_length) error_message | ||
integer, allocatable :: lengths(:) | ||
|
||
open(newunit=file_unit, file=file_name%string(), form='formatted', status='old', iostat=io_status, action='read') | ||
call assert(io_status==0,"read_lines: io_status==0 after 'open' statement", file_name%string()) | ||
|
||
lengths = line_lengths(file_unit) | ||
|
||
associate(num_lines => size(lengths)) | ||
|
||
allocate(file_object%lines_(num_lines)) | ||
|
||
do line_num = 1, num_lines | ||
allocate(character(len=lengths(line_num)) :: line) | ||
read(file_unit, '(a)', iostat=io_status, iomsg=error_message) line | ||
call assert(io_status==0,"read_lines: io_status==0 after line read", error_message) | ||
file_object%lines_(line_num) = string_t(line) | ||
deallocate(line) | ||
end do | ||
|
||
end associate | ||
|
||
close(file_unit) | ||
|
||
contains | ||
|
||
function line_count(file_unit) result(num_lines) | ||
integer, intent(in) :: file_unit | ||
integer num_lines | ||
|
||
rewind(file_unit) | ||
num_lines = 0 | ||
do | ||
read(file_unit, *, iostat=io_status) | ||
if (io_status==iostat_end) exit | ||
num_lines = num_lines + 1 | ||
end do | ||
rewind(file_unit) | ||
end function | ||
|
||
function line_lengths(file_unit) result(lengths) | ||
integer, intent(in) :: file_unit | ||
integer, allocatable :: lengths(:) | ||
integer io_status | ||
character(len=1) c | ||
|
||
associate(num_lines => line_count(file_unit)) | ||
|
||
allocate(lengths(num_lines), source = 0) | ||
rewind(file_unit) | ||
|
||
do line_num = 1, num_lines | ||
do | ||
read(file_unit, '(a)', advance='no', iostat=io_status, iomsg=error_message) c | ||
if (io_status==iostat_eor) exit | ||
lengths(line_num) = lengths(line_num) + 1 | ||
end do | ||
end do | ||
|
||
rewind(file_unit) | ||
|
||
end associate | ||
end function | ||
|
||
end procedure | ||
|
||
module procedure lines | ||
my_lines = self%lines_ | ||
end procedure | ||
|
||
end submodule file_s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.