diff --git a/models/weight_optimizer.cpp b/models/weight_optimizer.cpp index 12f264c724..eed7dd3dad 100644 --- a/models/weight_optimizer.cpp +++ b/models/weight_optimizer.cpp @@ -34,6 +34,8 @@ namespace nest WeightOptimizerCommonProperties::WeightOptimizerCommonProperties() : batch_size_( 1 ) , eta_( 1e-4 ) + , eta_first_( 1e-4 ) + , n_eta_change_( 0 ) , Wmin_( -100.0 ) , Wmax_( 100.0 ) , optimize_each_step_( true ) @@ -43,6 +45,8 @@ WeightOptimizerCommonProperties::WeightOptimizerCommonProperties() WeightOptimizerCommonProperties::WeightOptimizerCommonProperties( const WeightOptimizerCommonProperties& cp ) : batch_size_( cp.batch_size_ ) , eta_( cp.eta_ ) + , eta_first_( cp.eta_first_ ) + , n_eta_change_( cp.n_eta_change_ ) , Wmin_( cp.Wmin_ ) , Wmax_( cp.Wmax_ ) , optimize_each_step_( cp.optimize_each_step_ ) @@ -77,6 +81,16 @@ WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d ) { throw BadProperty( "Learning rate eta ≥ 0 required." ); } + + if ( new_eta != eta_ ) + { + if ( n_eta_change_ == 0 ) + { + eta_first_ = new_eta; + } + n_eta_change_ += 1; + } + eta_ = new_eta; double new_Wmin = Wmin_; @@ -96,6 +110,8 @@ WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d ) WeightOptimizer::WeightOptimizer() : sum_gradients_( 0.0 ) , optimization_step_( 1 ) + , eta_( 1e-4 ) + , n_optimize_( 0 ) { } @@ -110,11 +126,15 @@ WeightOptimizer::set_status( const DictionaryDatum& d ) } double -WeightOptimizer::optimized_weight( const WeightOptimizerCommonProperties& cp, +WeightOptimizer::optimized_weight( WeightOptimizerCommonProperties& cp, const size_t idx_current_update, const double gradient, double weight ) { + if ( cp.n_eta_change_ != 0 and n_optimize_ == 0 ) + { + eta_ = cp.eta_first_; + } sum_gradients_ += gradient; if ( optimization_step_ == 0 ) @@ -127,6 +147,8 @@ WeightOptimizer::optimized_weight( const WeightOptimizerCommonProperties& cp, { sum_gradients_ /= cp.batch_size_; weight = std::max( cp.Wmin_, std::min( optimize_( cp, weight, current_optimization_step ), cp.Wmax_ ) ); + eta_ = cp.eta_; + n_optimize_ += 1; optimization_step_ = current_optimization_step; } return weight; @@ -152,7 +174,7 @@ WeightOptimizerGradientDescent::WeightOptimizerGradientDescent() double WeightOptimizerGradientDescent::optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t ) { - weight -= cp.eta_ * sum_gradients_; + weight -= eta_ * sum_gradients_; sum_gradients_ = 0.0; return weight; } @@ -251,7 +273,7 @@ WeightOptimizerAdam::optimize_( const WeightOptimizerCommonProperties& cp, beta_1_power_ *= acp.beta_1_; beta_2_power_ *= acp.beta_2_; - const double alpha = cp.eta_ * std::sqrt( 1.0 - beta_2_power_ ) / ( 1.0 - beta_1_power_ ); + const double alpha = eta_ * std::sqrt( 1.0 - beta_2_power_ ) / ( 1.0 - beta_1_power_ ); m_ = acp.beta_1_ * m_ + ( 1.0 - acp.beta_1_ ) * sum_gradients_; v_ = acp.beta_2_ * v_ + ( 1.0 - acp.beta_2_ ) * sum_gradients_ * sum_gradients_; diff --git a/models/weight_optimizer.h b/models/weight_optimizer.h index c1e5aae6f8..d14205fcb3 100644 --- a/models/weight_optimizer.h +++ b/models/weight_optimizer.h @@ -207,9 +207,15 @@ class WeightOptimizerCommonProperties //! Size of an optimization batch. size_t batch_size_; - //! Learning rate. + //! Learning rate common to all synapses. double eta_; + //! First learning rate that differs from the default. + double eta_first_; + + //! Number of changes to the learning rate. + long n_eta_change_; + //! Minimal value for synaptic weight. double Wmin_; @@ -252,7 +258,7 @@ class WeightOptimizer virtual void set_status( const DictionaryDatum& d ); //! Return optimized weight based on current weight. - double optimized_weight( const WeightOptimizerCommonProperties& cp, + double optimized_weight( WeightOptimizerCommonProperties& cp, const size_t idx_current_update, const double gradient, double weight ); @@ -266,6 +272,12 @@ class WeightOptimizer //! Current optimization step, whereby optimization happens every batch_size_ steps. size_t optimization_step_; + + //! Learning rate private to the synapse. + double eta_; + + //! Number of optimizations. + long n_optimize_; }; /** diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py index b05c008339..76cd266b7c 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py @@ -117,7 +117,9 @@ # The task's temporal structure is then defined, once as time steps and once as durations in milliseconds. # Using a batch size larger than one aids the network in generalization, facilitating the solution to this task. # The original number of iterations requires distributed computing. Increasing the number of iterations -# enhances learning performance up to the point where overfitting occurs. +# enhances learning performance up to the point where overfitting occurs. If early stopping is enabled, the +# classification error is tested in regular intervals and the training stopped as soon as the error selected as +# stop criterion is reached. After training, the performance can be tested over a number of test iterations. batch_size = 32 # batch size, 64 in reference [2], 32 in the README to reference [2] n_iter = 50 # number of iterations, 2000 in reference [2] @@ -126,6 +128,16 @@ n_cues = 7 # number of cues given before decision prob_group = 0.3 # probability with which one input group is present +do_early_stopping = True # if True, stop training as soon as stop criterion fulfilled +n_validate_every = 10 # number of training iterations before validation +n_early_stop = 8 # number of iterations to average over to evaluate early stopping condition +stop_crit = 0.07 # error value corresponding to stop criterion for early stopping + +n_test = 4 # number of iterations for final test + +n_val = np.ceil(n_iter / n_validate_every) +n_iter_max = int(n_iter + n_val + (n_val - 1) * (n_early_stop + 1) + n_test) + steps = { "cue": 100, # time steps in one cue presentation "spacing": 50, # time steps of break between two cues @@ -136,7 +148,6 @@ steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals -steps["task"] = n_iter * batch_size * steps["sequence"] # time steps of task steps.update( { @@ -144,7 +155,7 @@ "delay_in_rec": 1, # connection delay between input and recurrent neurons "delay_rec_out": 1, # connection delay between recurrent and output neurons "delay_out_norm": 1, # connection delay between output neurons for normalization - "extension_sim": 1, # extra time step to close right-open simulation time interval in Simulate() + "extension_sim": 2, # extra time step to close right-open simulation time interval in Simulate() } ) @@ -152,8 +163,6 @@ steps["total_offset"] = steps["offset_gen"] + steps["delays"] # time steps of total offset -steps["sim"] = steps["task"] + steps["total_offset"] + steps["extension_sim"] # time steps of simulation - duration = {"step": 1.0} # ms, temporal resolution of the simulation duration.update({key: value * duration["step"] for key, value in steps.items()}) # ms, durations @@ -285,7 +294,6 @@ "interval": duration["step"], # interval between two recorded time points "record_from": ["V_m", "surrogate_gradient", "learning_signal"], # dynamic variables to record "start": duration["offset_gen"] + duration["delay_in_rec"], # start time of recording - "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], # stop time of recording "label": "multimeter_reg", } @@ -293,7 +301,6 @@ "interval": duration["step"], "record_from": params_mm_reg["record_from"] + ["V_th_adapt", "adaptation"], "start": duration["offset_gen"] + duration["delay_in_rec"], - "stop": duration["offset_gen"] + duration["delay_in_rec"] + duration["task"], "label": "multimeter_ad", } @@ -301,7 +308,6 @@ "interval": duration["step"], "record_from": ["V_m", "readout_signal", "readout_signal_unnorm", "target_signal", "error_signal"], "start": duration["total_offset"], - "stop": duration["total_offset"] + duration["task"], "label": "multimeter_out", } @@ -309,25 +315,21 @@ "senders": nrns_in[:n_record_w] + nrns_rec[:n_record_w], # limit senders to subsample weights to record "targets": nrns_rec[:n_record_w] + nrns_out, # limit targets to subsample weights to record from "start": duration["total_offset"], - "stop": duration["total_offset"] + duration["task"], "label": "weight_recorder", } params_sr_in = { "start": duration["offset_gen"], - "stop": duration["total_offset"] + duration["task"], "label": "spike_recorder_in", } params_sr_reg = { "start": duration["offset_gen"], - "stop": duration["total_offset"] + duration["task"], "label": "spike_recorder_reg", } params_sr_ad = { "start": duration["offset_gen"], - "stop": duration["total_offset"] + duration["task"], "label": "spike_recorder_ad", } @@ -376,7 +378,6 @@ def calculate_glorot_dist(fan_in, fan_out): "beta_1": 0.9, # exponential decay rate for 1st moment estimate of Adam optimizer "beta_2": 0.999, # exponential decay rate for 2nd moment raw estimate of Adam optimizer "epsilon": 1e-8, # small numerical stabilization constant of Adam optimizer - "eta": 5e-3, # learning rate "Wmin": -100.0, # pA, minimal limit of the synaptic weights "Wmax": 100.0, # pA, maximal limit of the synaptic weights }, @@ -384,6 +385,9 @@ def calculate_glorot_dist(fan_in, fan_out): "weight_recorder": wr, } +eta_test = 0.0 +eta_train = 5e-3 + params_syn_base = { "synapse_model": "eprop_synapse_bsshslm_2020", "delay": duration["step"], # ms, dendritic delay @@ -504,38 +508,57 @@ def generate_evidence_accumulation_input_output( return input_spike_bools, target_cues -input_spike_prob = 0.04 # spike probability of frozen input noise -dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 +def get_params_task_input_output(n_iter_interval): + input_spike_prob = 0.04 # spike probability of frozen input noise + dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32 + + input_spike_bools_list = [] + target_cues_list = [] -input_spike_bools_list = [] -target_cues_list = [] + for _ in range(n_iter_interval): + input_spike_bools, target_cues = generate_evidence_accumulation_input_output( + batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps + ) + input_spike_bools_list.append(input_spike_bools) + target_cues_list.extend(target_cues) -for _ in range(n_iter): - input_spike_bools, target_cues = generate_evidence_accumulation_input_output( - batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps + input_spike_bools_arr = np.array(input_spike_bools_list).reshape( + n_iter_interval * batch_size * steps["sequence"], n_in + ) + timeline_task = ( + np.arange( + 0.0, + n_iter_interval * batch_size * duration["sequence"], + duration["step"], + ) + + duration["offset_gen"] ) - input_spike_bools_list.append(input_spike_bools) - target_cues_list.extend(target_cues) -input_spike_bools_arr = np.array(input_spike_bools_list).reshape(steps["task"], n_in) -timeline_task = np.arange(0.0, duration["task"], duration["step"]) + duration["offset_gen"] + params_gen_spk_in = [ + {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)} + for nrn_in_idx in range(n_in) + ] -params_gen_spk_in = [ - {"spike_times": timeline_task[input_spike_bools_arr[:, nrn_in_idx]].astype(dtype_in_spks)} - for nrn_in_idx in range(n_in) -] + target_rate_changes = np.zeros((n_out, batch_size * n_iter_interval)) + target_rate_changes[np.array(target_cues_list), np.arange(batch_size * n_iter_interval)] = 1 -target_rate_changes = np.zeros((n_out, batch_size * n_iter)) -target_rate_changes[np.array(target_cues_list), np.arange(batch_size * n_iter)] = 1 + params_gen_rate_target = [ + { + "amplitude_times": np.arange( + 0.0, + n_iter_interval * batch_size * duration["sequence"], + duration["sequence"], + ) + + duration["total_offset"], + "amplitude_values": target_rate_changes[nrn_out_idx], + } + for nrn_out_idx in range(n_out) + ] + + return params_gen_spk_in, params_gen_rate_target -params_gen_rate_target = [ - { - "amplitude_times": np.arange(0.0, duration["task"], duration["sequence"]) + duration["total_offset"], - "amplitude_values": target_rate_changes[nrn_out_idx], - } - for nrn_out_idx in range(n_out) -] +params_gen_spk_in, params_gen_rate_target = get_params_task_input_output(n_iter_max) #################### @@ -551,7 +574,7 @@ def generate_evidence_accumulation_input_output( # the last update interval, by sending a strong spike to all neurons that form the presynaptic side of an eprop # synapse. This step is required purely for technical reasons. -gen_spk_final_update = nest.Create("spike_generator", 1, {"spike_times": [duration["task"] + duration["delays"]]}) +gen_spk_final_update = nest.Create("spike_generator", 1) nest.Connect(gen_spk_final_update, nrns_in + nrns_rec, "all_to_all", {"weight": 1000.0}) @@ -579,12 +602,134 @@ def get_weights(pop_pre, pop_post): } # %% ########################################################################################################### -# Simulate -# ~~~~~~~~ -# We train the network by simulating for a set simulation time, determined by the number of iterations and the -# batch size and the length of one sequence. +# Simulate and evaluate +# ~~~~~~~~~~~~~~~~~~~~~ +# We train the network by simulating for a number of training iterations with the set learning rate. If early +# stopping is turned on, we evaluate the network's performance on the validation set in regular intervals and, +# if the error is below a certain threshold, we stop the training early. If the error is not below the +# threshold, we continue training until the end of the set number of iterations. Finally, we evaluate the +# network's performance on the test set. +# Furthermore, we evaluate the network's training error by calculating a loss - in this case, the cross-entropy +# error between the integrated recurrent network activity and the target rate. + + +class TrainingPipeline: + def __init__(self): + self.results_dict = { + "error": [], + "loss": [], + "iteration": [], + "label": [], + } + self.n_iter_sim = 0 + self.phase_label_previous = "" + self.error = 0 + self.k_iter = 0 + self.early_stop = False + + def evaluate(self): + events_mm_out = mm_out.get("events") + + readout_signal = events_mm_out["readout_signal"] # corresponds to softmax + target_signal = events_mm_out["target_signal"] + senders = events_mm_out["senders"] + times = events_mm_out["times"] + + cond1 = times > (self.n_iter_sim - 1) * batch_size * duration["sequence"] + duration["total_offset"] + cond2 = times <= self.n_iter_sim * batch_size * duration["sequence"] + duration["total_offset"] + idc = cond1 & cond2 + + readout_signal = np.array([readout_signal[idc][senders[idc] == i] for i in set(senders)]) + target_signal = np.array([target_signal[idc][senders[idc] == i] for i in set(senders)]) + + readout_signal = readout_signal.reshape((n_out, 1, batch_size, steps["sequence"])) + target_signal = target_signal.reshape((n_out, 1, batch_size, steps["sequence"])) + + readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] + target_signal = target_signal[:, :, :, -steps["learning_window"] :] + + loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) + + y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) + y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) + accuracy = np.mean((y_target == y_prediction), axis=1) + errors = 1.0 - accuracy + + self.results_dict["iteration"].append(self.n_iter_sim) + self.results_dict["error"].extend(errors) + self.results_dict["loss"].extend(loss) + self.results_dict["label"].append(self.phase_label_previous) + + self.error = errors[0] + + def run(self, phase_label, eta): + params_common_syn_eprop["optimizer"]["eta"] = eta + nest.SetDefaults("eprop_synapse_bsshslm_2020", params_common_syn_eprop) + + nest.Simulate(duration["extension_sim"]) + if self.n_iter_sim > 0: + self.evaluate() + + duration["sim"] = batch_size * duration["sequence"] - duration["extension_sim"] + + nest.Simulate(duration["sim"]) -nest.Simulate(duration["sim"]) + self.n_iter_sim += 1 + self.phase_label_previous = phase_label + + def run_training(self): + self.run("training", eta_train) + + def run_validation(self): + if do_early_stopping and self.k_iter % n_validate_every == 0: + self.run("validation", eta_test) + + def run_early_stopping(self): + if do_early_stopping and self.k_iter % n_validate_every == 0: + if self.k_iter > 0 and self.error < stop_crit: + errors_early_stop = [] + for _ in range(n_early_stop): + self.run("early-stopping", eta_test) + errors_early_stop.append(self.error) + + self.early_stop = np.mean(errors_early_stop) < stop_crit + + def run_test(self): + for _ in range(n_test): + self.run("test", eta_test) + + def simulate(self): + nest.Simulate(duration["total_offset"]) + + while self.k_iter < n_iter and not self.early_stop: + self.run_validation() + self.run_early_stopping() + self.run_training() + self.k_iter += 1 + + self.run_test() + + nest.Simulate(steps["extension_sim"]) + + self.evaluate() + + duration["task"] = self.n_iter_sim * batch_size * duration["sequence"] + duration["total_offset"] + + gen_spk_final_update.set({"spike_times": [duration["task"] + duration["extension_sim"] + 1]}) + + nest.Simulate(duration["delays"]) + + def get_results(self): + for k, v in self.results_dict.items(): + self.results_dict[k] = np.array(v) + return self.results_dict + + +training_pipeline = TrainingPipeline() +training_pipeline.simulate() + +results_dict = training_pipeline.get_results() +n_iter_sim = training_pipeline.n_iter_sim # %% ########################################################################################################### # Read out post-training weights @@ -610,32 +755,6 @@ def get_weights(pop_pre, pop_post): events_sr_ad = sr_ad.get("events") events_wr = wr.get("events") -# %% ########################################################################################################### -# Evaluate training error -# ~~~~~~~~~~~~~~~~~~~~~~~ -# We evaluate the network's training error by calculating a loss - in this case, the cross-entropy error between -# the integrated recurrent network activity and the target rate. - -readout_signal = events_mm_out["readout_signal"] # corresponds to softmax -target_signal = events_mm_out["target_signal"] -senders = events_mm_out["senders"] - -readout_signal = np.array([readout_signal[senders == i] for i in set(senders)]) -target_signal = np.array([target_signal[senders == i] for i in set(senders)]) - -readout_signal = readout_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) -target_signal = target_signal.reshape((n_out, n_iter, batch_size, steps["sequence"])) - -readout_signal = readout_signal[:, :, :, -steps["learning_window"] :] -target_signal = target_signal[:, :, :, -steps["learning_window"] :] - -loss = -np.mean(np.sum(target_signal * np.log(readout_signal), axis=0), axis=(1, 2)) - -y_prediction = np.argmax(np.mean(readout_signal, axis=3), axis=0) -y_target = np.argmax(np.mean(target_signal, axis=3), axis=0) -accuracy = np.mean((y_target == y_prediction), axis=1) -recall_errors = 1.0 - accuracy - # %% ########################################################################################################### # Plot results # ~~~~~~~~~~~~ @@ -649,6 +768,8 @@ def get_weights(pop_pre, pop_post): colors = { "blue": "#2854c5ff", "red": "#e04b40ff", + "green": "#25aa2cff", + "gold": "#f9c643ff", "white": "#ffffffff", } @@ -656,27 +777,30 @@ def get_weights(pop_pre, pop_post): { "axes.spines.right": False, "axes.spines.top": False, - "axes.prop_cycle": cycler(color=[colors["blue"], colors["red"]]), + "axes.prop_cycle": cycler(color=[colors[k] for k in ["blue", "red", "green", "gold"]]), } ) # %% ########################################################################################################### -# Plot training error -# ................... -# We begin with two plots visualizing the training error of the network: the loss and the recall error, both +# Plot error +# .......... +# We begin with two plots visualizing the error of the network: the loss and the recall error, both # plotted against the iterations. fig, axs = plt.subplots(2, 1, sharex=True) fig.suptitle("Training error") -axs[0].plot(range(1, n_iter + 1), loss) +for color, label in zip(colors, set(results_dict["label"])): + idc = results_dict["label"] == label + axs[0].scatter(results_dict["iteration"][idc], results_dict["loss"][idc], label=label) + axs[1].scatter(results_dict["iteration"][idc], results_dict["error"][idc], label=label) + axs[0].set_ylabel(r"$E = -\sum_{t,k} \pi_k^{*,t} \log \pi_k^t$") -axs[1].plot(range(1, n_iter + 1), recall_errors) axs[1].set_ylabel("recall errors") -axs[-1].set_xlabel("training iteration") -axs[-1].set_xlim(1, n_iter) +axs[-1].set_xlabel("iteration") +axs[-1].legend(bbox_to_anchor=(1.05, 0.5), loc="center left") axs[-1].xaxis.get_major_locator().set_params(integer=True) fig.tight_layout() @@ -711,7 +835,10 @@ def plot_spikes(ax, events, ylabel, xlims): for title, xlims in zip( ["Dynamic variables before training", "Dynamic variables after training"], - [(0, steps["sequence"]), (steps["task"] - steps["sequence"], steps["task"])], + [ + (0, steps["sequence"]), + ((n_iter_sim - 1) * batch_size * steps["sequence"], n_iter_sim * batch_size * steps["sequence"]), + ], ): fig, axs = plt.subplots(14, 1, sharex=True, figsize=(8, 14), gridspec_kw={"hspace": 0.4, "left": 0.2}) fig.suptitle(title) @@ -775,7 +902,7 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") -axs[-1].set_xlim(0, steps["task"]) +axs[-1].set_xlim(0, duration["task"]) fig.align_ylabels() fig.tight_layout()