Skip to content

Commit

Permalink
feat(app): reshuffle training data for each epoch
Browse files Browse the repository at this point in the history
The goal of this change is to nudge the stochastic gradient
algorithm away form local minima in the cost function.
  • Loading branch information
rouson committed Sep 10, 2023
1 parent 2fc0630 commit 09e0928
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions app/train-cloud-microphysics.f90
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ subroutine read_train_write
real(rkind), allocatable :: cost(:)
real(rkind), allocatable :: harvest(:)
integer, parameter :: mini_batch_size=1
integer i, batch, lon, lat, level, time, network_unit, io_status, final_step
integer i, batch, lon, lat, level, time, network_unit, io_status, final_step, epoch

open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read')

Expand Down Expand Up @@ -231,24 +231,18 @@ subroutine read_train_write
end associate
end associate

call shuffle(input_output_pairs) ! set up for stochastic gradient descent

associate(num_pairs => size(input_output_pairs), n_bins => size(input_output_pairs)/10000)
bins = [(bin_t(num_items=num_pairs, num_bins=n_bins, bin_number=b), b = 1, n_bins)]
mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))]
end associate

print *,"Training network"

block
integer epoch

print *,"Training network"
print *, " Epoch Cost (min) Cost (max) Cost (avg)"

do epoch = starting_epoch, ending_epoch

call shuffle(input_output_pairs) ! set up for stochastic gradient descent
mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))]
call trainable_engine%train(mini_batches, cost)

print *, epoch, minval(cost), maxval(cost), sum(cost)/size(cost)
write(plot_unit,*) epoch, minval(cost), maxval(cost), sum(cost)/size(cost)

Expand All @@ -262,7 +256,7 @@ subroutine read_train_write
close(network_unit)

end do
end block
end associate

close(plot_unit)

Expand Down

0 comments on commit 09e0928

Please sign in to comment.