Skip to content

Commit

Permalink
Merge pull request #81 from BerkeleyLab/epoch
Browse files: browse the repository at this point in the history
feat(app): train in strided epochs
  • Loading branch information
rouson authored Sep 10, 2023
2 parents 1634577 + a26c28c commit 2fc0630
Showing 1 changed file with 84 additions and 78 deletions.
162 changes: 84 additions & 78 deletions app/train-cloud-microphysics.f90
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,41 @@ program train_cloud_microphysics

integer(int64) t_start, t_finish, clock_rate
type(command_line_t) command_line
character(len=:), allocatable :: base_name, steps
integer plot_file, dash, starting_step, ending_step
logical user_specified_time_range
type(file_t) plot_file
type(string_t), allocatable :: lines(:)
character(len=*), parameter :: plot_file_name = "cost.plt"
character(len=:), allocatable :: base_name, stride_string, epochs_string, last_line
integer plot_unit, stride, starting_epoch, ending_epoch, num_epochs, last_epoch_in_file
logical preexisting_plot_file

call system_clock(t_start, clock_rate)

base_name = command_line%flag_value("--base-name") ! gfortran 13 seg faults if this is an association
steps = command_line%flag_value("--steps")
base_name = command_line%flag_value("--base") ! gfortran 13 seg faults if this is an association
stride_string = command_line%flag_value("--stride")
epochs_string = command_line%flag_value("--epochs")

if (len(base_name)==0) error stop new_line('a') // new_line('a') // &
'Usage: ./build/run-fpm.sh run train-cloud-microphysics -- --base-name "<file-base-name>"'
if (any([len(base_name),len(stride_string),len(epochs_string)]==0)) error stop new_line('a') // new_line('a') // &
'Usage: ./build/run-fpm.sh run train-cloud-microphysics -- --base <string> --stride <integer> --epochs <integer>'

if (len(steps)==0) then
user_specified_time_range = .false.
print *,"No user-specified time step range. All steps will be used."
read(stride_string,*) stride
read(epochs_string,*) num_epochs

inquire(file=plot_file_name, exist=preexisting_plot_file)
open(newunit=plot_unit,file="cost.plt",status="unknown",position="append")

if (.not. preexisting_plot_file) then
write(plot_unit,*) " Epoch Cost (min) Cost (max) Cost (avg)"
starting_epoch = 1
else
dash = scan(steps,"-")
read(steps(1:dash-1),*) starting_step
read(steps(dash+1:len(steps)),*) ending_step
user_specified_time_range = .true.
print *,"User-specified time step range: ", starting_step, "-", ending_step
plot_file = file_t(string_t(plot_file_name))
lines = plot_file%lines()
last_line = lines(size(lines))%string()
read(last_line,*) last_epoch_in_file
starting_epoch = last_epoch_in_file + 1
end if

ending_epoch = starting_epoch + num_epochs - 1

call read_train_write

call system_clock(t_finish)
Expand Down Expand Up @@ -174,91 +186,85 @@ subroutine read_train_write
real(rkind), allocatable :: cost(:)
real(rkind), allocatable :: harvest(:)
integer, parameter :: mini_batch_size=1
integer i, batch, lon, lat, level, time, file_unit, io_status, final_step
integer i, batch, lon, lat, level, time, network_unit, io_status, final_step

open(newunit=file_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')
open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')

if (io_status==0) then
print *,"Reading network from file " // network_file
trainable_engine = trainable_engine_t(inference_engine_t(file_t(string_t(network_file))))
close(file_unit)
close(network_unit)
else
close(file_unit)
close(network_unit)
print *,"Initializing a new network"
trainable_engine = new_engine(num_hidden_layers=12, nodes_per_hidden_layer=16, num_inputs=8, num_outputs=6, random=.true.)
end if

open(newunit=plot_file,file="cost.plt",status="replace")
write(plot_file,*) "step min(cost) max(cost) avg(cost)"

if (.not. user_specified_time_range) then
starting_step = 1
ending_step = size(time_in)
end if

print *,"Training based on data from step", starting_step, "to", ending_step

do time = starting_step, ending_step

if (time==1) print *,"Defining tensors for step ",time

! The following temporary copies are required by gfortran bug 100650 and possibly 49324
! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
inputs = [( [( [( &
tensor_t( &
[ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time),&
qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qi_in(lon,lat,level,time), qr_in(lon,lat,level,time), &
qs_in(lon,lat,level,time) &
print *,"Defining tensors from time steps 1 through", t_end, "with strides of", stride

! The following temporary copies are required by gfortran bug 100650 and possibly 49324
! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324
inputs = [( [( [( [( &
tensor_t( &
[ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time),&
qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qi_in(lon,lat,level,time), qr_in(lon,lat,level,time), &
qs_in(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = 1, t_end, stride)]

outputs = [( [( [( [( &
tensor_t( &
[dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), &
dqi_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), dqs_dt(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))]

outputs = [( [( [( &
tensor_t( &
[dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), &
dqi_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), dqs_dt(lon,lat,level,time) &
] &
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))]

if (time==1) print *, "Eliminating",int(100*(1.-keep)),"% of the grid points that have all-zero time derivatives"

associate(num_grid_pts => size(outputs))
if (allocated(harvest)) deallocate(harvest)
allocate(harvest(num_grid_pts))
call random_number(harvest)
associate(keepers => [(any(outputs(i)%values()/=0.) .or. harvest(i)<keep, i=1,num_grid_pts)])
input_output_pairs = input_output_pair_t(pack(inputs, keepers), pack(outputs, keepers))
if (time==1) print *, size(input_output_pairs), "points retained out of ", num_grid_pts, " points total"
end associate
), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = 1, t_end, stride)]

print *, "Eliminating",int(100*(1.-keep)),"% of the grid points that have all-zero time derivatives"

associate(num_grid_pts => size(outputs))
if (allocated(harvest)) deallocate(harvest)
allocate(harvest(num_grid_pts))
call random_number(harvest)
associate(keepers => [(any(outputs(i)%values()/=0.) .or. harvest(i)<keep, i=1,num_grid_pts)])
input_output_pairs = input_output_pair_t(pack(inputs, keepers), pack(outputs, keepers))
print *, size(input_output_pairs), "points retained out of ", num_grid_pts, " points total"
end associate
end associate

call shuffle(input_output_pairs) ! set up for stochastic gradient descent
call shuffle(input_output_pairs) ! set up for stochastic gradient descent

associate(num_pairs => size(input_output_pairs), n_bins => size(input_output_pairs)/10000)
bins = [(bin_t(num_items=num_pairs, num_bins=n_bins, bin_number=b), b = 1, n_bins)]
mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))]
end associate
associate(num_pairs => size(input_output_pairs), n_bins => size(input_output_pairs)/10000)
bins = [(bin_t(num_items=num_pairs, num_bins=n_bins, bin_number=b), b = 1, n_bins)]
mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))]
end associate

if (time==1) print *,"Training network"
print *,"Training network"

call trainable_engine%train(mini_batches, cost)
block
integer epoch

print *, "step, cost_{min,max,avg}: ", time, minval(cost), maxval(cost), sum(cost)/size(cost)
print *, " Epoch Cost (min) Cost (max) Cost (avg)"

write(plot_file,*) time, minval(cost), maxval(cost), sum(cost)/size(cost)
do epoch = starting_epoch, ending_epoch

end do
call trainable_engine%train(mini_batches, cost)

open(newunit=file_unit, file=network_file, form='formatted', status='unknown', iostat=io_status, action='write')
print *, epoch, minval(cost), maxval(cost), sum(cost)/size(cost)
write(plot_unit,*) epoch, minval(cost), maxval(cost), sum(cost)/size(cost)

associate(inference_engine => trainable_engine%to_inference_engine())
associate(json_file => inference_engine%to_json())
print *,"Writing network to " // network_file
call json_file%write_lines(string_t(network_file))
end associate
end associate
open(newunit=network_unit, file=network_file, form='formatted', status='unknown', iostat=io_status, action='write')
associate(inference_engine => trainable_engine%to_inference_engine())
associate(json_file => inference_engine%to_json())
print *,"Writing network to " // network_file
call json_file%write_lines(string_t(network_file))
end associate
end associate
close(network_unit)

end do
end block

close(file_unit)
close(plot_file)
close(plot_unit)

end block train_network

Expand Down

0 comments on commit 2fc0630

Please sign in to comment.