Skip to content

Commit

Permalink
feat(hyperparameters): def derived type + json I/O
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Oct 25, 2023
1 parent 02e450a commit b024b70
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/inference_engine/hyperparameters_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module hyperparameters_m
use sourcery_m, only : file_t
implicit none

private
public :: hyperparameters_t
public :: initialization_parameters_t
public :: initialization_t

type initialization_parameters_t
real spread_
end type

type initialization_t
character(len=:), allocatable :: initialization_type_
type(initialization_parameters_t) :: initialization_parameters_
end type

type hyperparameters_t
private
character(len=:), allocatable :: activation_
integer mini_batch_size_
integer, allocatable :: nodes_per_layer_(:)
type(initialization_t) initialization_
contains
procedure :: to_json
end type

interface hyperparameters_t

pure module function construct_from_json_file(file_) result(hyperparameters)
implicit none
type(file_t), intent(in) :: file_
type(hyperparameters_t) hyperparameters
end function

end interface

interface

impure elemental module function to_json(self) result(json_file)
implicit none
class(hyperparameters_t), intent(in) :: self
type(file_t) json_file
end function

end interface

end module
54 changes: 54 additions & 0 deletions src/inference_engine/hyperparameters_s.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
submodule(hyperparameters_m) hyperparameters_s
use assert_m, only : assert, intrinsic_array_t
use sourcery_m, only : string_t
implicit none

contains

module procedure construct_from_json_file
integer l
type(string_t), allocatable :: lines(:)

lines = file_%lines()

l = 1
call assert(adjustl(lines(l)%string())=="{", 'construct_from_json_file: adjustl(lines(l)%string())=="{"', lines(l)%string())

!{
! "activation" : "sigmoid",
! "num_mini_batches" : 10,
! "nodes per layer" : [2, 72, 2],
! "initialization" : {
! "type" : "perturbed identity",
! "parameters" : [ { "spread" : 0.05 } ]
! }
!}


l = l + 1
call assert(adjustl(lines(l)%string())=="}", 'construct_from_json_file: adjustl(lines(l)%string())=="}"', lines(l)%string())
end procedure

module procedure to_json
type(string_t), allocatable :: lines(:)
integer, parameter :: outer_object_braces = 2
integer, parameter :: num_lines = outer_object_braces
integer l

allocate(lines(num_lines))

l = 1
lines(l) = string_t('{')

l = l + 1
!lines(line) = string_t(' "modelName": "' // &
!self%metadata_(findloc(key, "modelName", dim=1))%string() // '",')



l = l + 1
call assert(l == num_lines, "hyperparameters_s(to_json): l == num_lines", intrinsic_array_t([l,num_lines]))
lines(l) = string_t('}')
end procedure

end submodule hyperparameters_s
57 changes: 57 additions & 0 deletions test/hyperparameters_test_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt
module hyperparameters_test_m
!! Test hyperparameters_t object I/O and construction

! External dependencies
use assert_m, only : assert
use sourcery_m, only : string_t, test_t, test_result_t

! Internal dependencies
use hyperparameters_m, only : hyperparameters_t

implicit none

private
public :: hyperparameters_test_t

type, extends(test_t) :: hyperparameters_test_t
contains
procedure, nopass :: subject
procedure, nopass :: results
end type

contains

pure function subject() result(specimen)
character(len=:), allocatable :: specimen
specimen = "A hyperparameters_t object"
end function

function results() result(test_results)
type(test_result_t), allocatable :: test_results(:)

character(len=*), parameter :: longest_description = &
"writing and then reading gives input matching output for perturbed identity network"

associate( &
descriptions => &
[ character(len=len(longest_description)) :: &
"writing and then reading gives input matching output for perturbed identity network" &
], &
outcomes => &
[ write_then_read_perturbed_identity() &
] &
)
call assert(size(descriptions) == size(outcomes),"hyperparameters_test_m(results): size(descriptions) == size(outcomes)")
test_results = test_result_t(descriptions, outcomes)
end associate

end function

function write_then_read_perturbed_identity() result(test_passes)
logical, allocatable :: test_passes(:)
test_passes = [.true.]
end function

end module hyperparameters_test_m
3 changes: 3 additions & 0 deletions test/main.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ program main
use inference_engine_test_m, only : inference_engine_test_t
use asymmetric_engine_test_m, only : asymmetric_engine_test_t
use trainable_engine_test_m, only : trainable_engine_test_t
use hyperparameters_test_m, only : hyperparameters_test_t
implicit none

type(inference_engine_test_t) inference_engine_test
type(asymmetric_engine_test_t) asymmetric_engine_test
type(trainable_engine_test_t) trainable_engine_test
type(hyperparameters_test_t) hyperparameters_test
real t_start, t_finish

integer :: passes=0, tests=0
Expand All @@ -18,6 +20,7 @@ program main
call inference_engine_test%report(passes, tests)
call asymmetric_engine_test%report(passes, tests)
call trainable_engine_test%report(passes, tests)
call hyperparameters_test%report(passes, tests)
#ifndef __INTEL_FORTRAN
block
use netCDF_file_test_m, only : netCDF_file_test_t
Expand Down
9 changes: 9 additions & 0 deletions training_parameters.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"activation" : "sigmoid",
"num_mini_batches" : 10,
"nodes per layer" : [2, 72, 2],
"initialization" : {
"type" : "perturbed identity",
"parameters" : [ { "spread" : 0.05 } ]
}
}

0 comments on commit b024b70

Please sign in to comment.