From a26c28ceac84f2aec735b08a1a459584483832ab Mon Sep 17 00:00:00 2001 From: Damian Rouson Date: Sun, 10 Sep 2023 01:04:59 -0700 Subject: [PATCH] feat(app): train in strided epochs This commit edits or adds command-line arguments of the form --base --stride --epochs where stride --base was previously --base-name, --stride is the step interval between time instances that will be used for training, and --epochs is the number of complete passes through the data set for the current run. --- app/train-cloud-microphysics.f90 | 162 ++++++++++++++++--------------- 1 file changed, 84 insertions(+), 78 deletions(-) diff --git a/app/train-cloud-microphysics.f90 b/app/train-cloud-microphysics.f90 index 3ecf4c888..adda20a38 100644 --- a/app/train-cloud-microphysics.f90 +++ b/app/train-cloud-microphysics.f90 @@ -44,29 +44,41 @@ program train_cloud_microphysics integer(int64) t_start, t_finish, clock_rate type(command_line_t) command_line - character(len=:), allocatable :: base_name, steps - integer plot_file, dash, starting_step, ending_step - logical user_specified_time_range + type(file_t) plot_file + type(string_t), allocatable :: lines(:) + character(len=*), parameter :: plot_file_name = "cost.plt" + character(len=:), allocatable :: base_name, stride_string, epochs_string, last_line + integer plot_unit, stride, starting_epoch, ending_epoch, num_epochs, last_epoch_in_file + logical preexisting_plot_file call system_clock(t_start, clock_rate) - base_name = command_line%flag_value("--base-name") ! gfortran 13 seg faults if this is an association - steps = command_line%flag_value("--steps") + base_name = command_line%flag_value("--base") ! gfortran 13 seg faults if this is an association + stride_string = command_line%flag_value("--stride") + epochs_string = command_line%flag_value("--epochs") - if (len(base_name)==0) error stop new_line('a') // new_line('a') // & - 'Usage: ./build/run-fpm.sh run train-cloud-microphysics -- --base-name ""' + if (any([len(base_name),len(stride_string),len(epochs_string)]==0)) error stop new_line('a') // new_line('a') // & + 'Usage: ./build/run-fpm.sh run train-cloud-microphysics -- --base --stride --epochs ' - if (len(steps)==0) then - user_specified_time_range = .false. - print *,"No user-specified time step range. All steps will be used." + read(stride_string,*) stride + read(epochs_string,*) num_epochs + + inquire(file=plot_file_name, exist=preexisting_plot_file) + open(newunit=plot_unit,file="cost.plt",status="unknown",position="append") + + if (.not. preexisting_plot_file) then + write(plot_unit,*) " Epoch Cost (min) Cost (max) Cost (avg)" + starting_epoch = 1 else - dash = scan(steps,"-") - read(steps(1:dash-1),*) starting_step - read(steps(dash+1:len(steps)),*) ending_step - user_specified_time_range = .true. - print *,"User-specified time step range: ", starting_step, "-", ending_step + plot_file = file_t(string_t(plot_file_name)) + lines = plot_file%lines() + last_line = lines(size(lines))%string() + read(last_line,*) last_epoch_in_file + starting_epoch = last_epoch_in_file + 1 end if + ending_epoch = starting_epoch + num_epochs - 1 + call read_train_write call system_clock(t_finish) @@ -174,91 +186,85 @@ subroutine read_train_write real(rkind), allocatable :: cost(:) real(rkind), allocatable :: harvest(:) integer, parameter :: mini_batch_size=1 - integer i, batch, lon, lat, level, time, file_unit, io_status, final_step + integer i, batch, lon, lat, level, time, network_unit, io_status, final_step - open(newunit=file_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read') + open(newunit=network_unit, file=network_file, form='formatted', status='old', iostat=io_status, action='read') if (io_status==0) then print *,"Reading network from file " // network_file trainable_engine = trainable_engine_t(inference_engine_t(file_t(string_t(network_file)))) - close(file_unit) + close(network_unit) else - close(file_unit) + close(network_unit) print *,"Initializing a new network" trainable_engine = new_engine(num_hidden_layers=12, nodes_per_hidden_layer=16, num_inputs=8, num_outputs=6, random=.true.) end if - open(newunit=plot_file,file="cost.plt",status="replace") - write(plot_file,*) "step min(cost) max(cost) avg(cost)" - - if (.not. user_specified_time_range) then - starting_step = 1 - ending_step = size(time_in) - end if - - print *,"Training based on data from step", starting_step, "to", ending_step - - do time = starting_step, ending_step - - if (time==1) print *,"Defining tensors for step ",time - - ! The following temporary copies are required by gfortran bug 100650 and possibly 49324 - ! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324 - inputs = [( [( [( & - tensor_t( & - [ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time),& - qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qi_in(lon,lat,level,time), qr_in(lon,lat,level,time), & - qs_in(lon,lat,level,time) & + print *,"Defining tensors from time steps 1 through", t_end, "with strides of", stride + + ! The following temporary copies are required by gfortran bug 100650 and possibly 49324 + ! See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100650 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=49324 + inputs = [( [( [( [( & + tensor_t( & + [ pressure_in(lon,lat,level,time), potential_temperature_in(lon,lat,level,time), temperature_in(lon,lat,level,time),& + qv_in(lon,lat,level,time), qc_in(lon,lat,level,time), qi_in(lon,lat,level,time), qr_in(lon,lat,level,time), & + qs_in(lon,lat,level,time) & + ] & + ), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))], time = 1, t_end, stride)] + + outputs = [( [( [( [( & + tensor_t( & + [dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), & + dqi_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), dqs_dt(lon,lat,level,time) & ] & - ), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))] - - outputs = [( [( [( & - tensor_t( & - [dpt_dt(lon,lat,level,time), dqv_dt(lon,lat,level,time), dqc_dt(lon,lat,level,time), & - dqi_dt(lon,lat,level,time), dqr_dt(lon,lat,level,time), dqs_dt(lon,lat,level,time) & - ] & - ), lon = 1, size(qv_in,1))], lat = 1, size(qv_in,2))], level = 1, size(qv_in,3))] - - if (time==1) print *, "Eliminating",int(100*(1.-keep)),"% of the grid points that have all-zero time derivatives" - - associate(num_grid_pts => size(outputs)) - if (allocated(harvest)) deallocate(harvest) - allocate(harvest(num_grid_pts)) - call random_number(harvest) - associate(keepers => [(any(outputs(i)%values()/=0.) .or. harvest(i) size(outputs)) + if (allocated(harvest)) deallocate(harvest) + allocate(harvest(num_grid_pts)) + call random_number(harvest) + associate(keepers => [(any(outputs(i)%values()/=0.) .or. harvest(i) size(input_output_pairs), n_bins => size(input_output_pairs)/10000) - bins = [(bin_t(num_items=num_pairs, num_bins=n_bins, bin_number=b), b = 1, n_bins)] - mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))] - end associate + associate(num_pairs => size(input_output_pairs), n_bins => size(input_output_pairs)/10000) + bins = [(bin_t(num_items=num_pairs, num_bins=n_bins, bin_number=b), b = 1, n_bins)] + mini_batches = [(mini_batch_t(input_output_pairs(bins(b)%first():bins(b)%last())), b = 1, size(bins))] + end associate - if (time==1) print *,"Training network" + print *,"Training network" - call trainable_engine%train(mini_batches, cost) + block + integer epoch - print *, "step, cost_{min,max,avg}: ", time, minval(cost), maxval(cost), sum(cost)/size(cost) + print *, " Epoch Cost (min) Cost (max) Cost (avg)" - write(plot_file,*) time, minval(cost), maxval(cost), sum(cost)/size(cost) + do epoch = starting_epoch, ending_epoch - end do + call trainable_engine%train(mini_batches, cost) - open(newunit=file_unit, file=network_file, form='formatted', status='unknown', iostat=io_status, action='write') + print *, epoch, minval(cost), maxval(cost), sum(cost)/size(cost) + write(plot_unit,*) epoch, minval(cost), maxval(cost), sum(cost)/size(cost) - associate(inference_engine => trainable_engine%to_inference_engine()) - associate(json_file => inference_engine%to_json()) - print *,"Writing network to " // network_file - call json_file%write_lines(string_t(network_file)) - end associate - end associate + open(newunit=network_unit, file=network_file, form='formatted', status='unknown', iostat=io_status, action='write') + associate(inference_engine => trainable_engine%to_inference_engine()) + associate(json_file => inference_engine%to_json()) + print *,"Writing network to " // network_file + call json_file%write_lines(string_t(network_file)) + end associate + end associate + close(network_unit) + + end do + end block - close(file_unit) - close(plot_file) + close(plot_unit) end block train_network