Add parameters to the doxygen description
akorgor committed Sep 13, 2024
1 parent 62a3d92 commit d2ffcdb
Showing 2 changed files with 215 additions and 49 deletions.
233 changes: 190 additions & 43 deletions nestkernel/eprop_archiving_node.h
@@ -35,6 +35,8 @@
namespace nest
{
/**
* @brief Base class implementing archiving for node models supporting e-prop plasticity.
*
* Base class implementing an intermediate archiving node model for node models supporting e-prop plasticity
* according to Bellec et al. (2020) and supporting additional biological features described in Korcsak-Gorzo,
* Stapmanns, and Espinoza Valverde et al. (in preparation).
@@ -44,42 +46,64 @@ namespace nest
* e-prop plasticity. It further provides a set of get, write, and set functions
* for these histories and the hardcoded shifts to synchronize the factors of
* the plasticity rule.
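 *
 * A typical archiving cycle might look as follows (a sketch only; `t_prev` and `t_curr` are
 * illustrative update time steps):
 * @code
 * register_eprop_connection();                // once when an e-prop synapse is created
 * write_update_to_history( t_prev, t_curr );  // at each weight update
 * auto it = get_update_history( t_curr );     // look up the newly written entry
 * @endcode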
*
* @tparam HistEntryT The type of history entry.
*/
template < typename HistEntryT >
class EpropArchivingNode : public Node
{

public:
//! Default constructor.
/**
* Constructs a new EpropArchivingNode object.
*/
EpropArchivingNode();

//! Copy constructor.
EpropArchivingNode( const EpropArchivingNode& );
/**
* Constructs a new EpropArchivingNode object by copying another EpropArchivingNode object.
*
* @param other The other object to copy.
*/
EpropArchivingNode( const EpropArchivingNode& other );

//! Initialize the update history and register the eprop synapse.
void register_eprop_connection( const bool is_bsshslm_2020_model = true ) override;

//! Register current update in the update history and deregister previous update.
void write_update_to_history( const long t_previous_update,
const long t_current_update,
const long eprop_isi_trace_cutoff = 0,
const bool erase = false ) override;

//! Get an iterator pointing to the update history entry of the given time step.
/**
* Retrieves the update history entry for a specific time step.
*
* @param time_step The time step.
* @return An iterator pointing to the update history for the specified time step.
*/
std::vector< HistEntryEpropUpdate >::iterator get_update_history( const long time_step );

//! Get an iterator pointing to the eprop history entry of the given time step.
/**
* Retrieves the eprop history entry for a specified time step.
*
* @param time_step The time step.
* @return An iterator pointing to the eprop history entry for the specified time step.
*/
typename std::vector< HistEntryT >::iterator get_eprop_history( const long time_step );

/**
* Erase e-prop history entries for update intervals during which no spikes were sent to the target neuron,
* @brief Erases the used eprop history for `bsshslm_2020` models.
*
* Erases e-prop history entries for update intervals during which no spikes were sent to the target neuron,
* and any entries older than the earliest time stamp required by the first update in the history.
*/
void erase_used_eprop_history();

/**
* Erase e-prop history entries between the last and penultimate updates if they exceed the inter-spike
* @brief Erases the used eprop history.
*
* Erases e-prop history entries between the last and penultimate updates if they exceed the inter-spike
* interval trace cutoff and any entries older than the earliest time stamp required by the first update.
*
* @param eprop_isi_trace_cutoff The cutoff value for the inter-spike integration of the eprop trace.
*/
void erase_used_eprop_history( const long eprop_isi_trace_cutoff );

@@ -99,16 +123,16 @@ class EpropArchivingNode : public Node
//! Offset since generator signals start from time step 1.
const long offset_gen_ = 1;

//! Connection delay from input to recurrent neurons.
//! Transmission delay from input to recurrent neurons.
const long delay_in_rec_ = 1;

//! Connection delay from recurrent to output neurons.
//! Transmission delay from recurrent to output neurons.
const long delay_rec_out_ = 1;

//! Connection delay between output neurons for normalization.
//! Transmission delay between output neurons for normalization.
const long delay_out_norm_ = 1;

//! Connection delay from output neurons to recurrent neurons.
//! Transmission delay from output neurons to recurrent neurons.
const long delay_out_rec_ = 1;
};

@@ -119,25 +143,45 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec
{

public:
//! Default constructor.
/**
* Constructs a new EpropArchivingNodeRecurrent object.
*/
EpropArchivingNodeRecurrent();

//! Copy constructor.
EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& );
/**
* Constructs an EpropArchivingNodeRecurrent object by copying another EpropArchivingNodeRecurrent object.
*
* @param other The EpropArchivingNodeRecurrent object to copy.
*/
EpropArchivingNodeRecurrent( const EpropArchivingNodeRecurrent& other );

/**
* Define pointer-to-member function type for surrogate gradient function.
* Defines the pointer-to-member function type for the surrogate gradient function.
*
* @note The typename is `surrogate_gradient_function`. All parentheses in the expression are required.
*/
typedef double (
EpropArchivingNodeRecurrent::*surrogate_gradient_function )( double, double, double, double, double );

//! Select the surrogate gradient function.
/**
* Selects a surrogate gradient function based on the specified name.
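 *
 * A minimal selection-and-invocation sketch (the name string and argument values are illustrative):
 * @code
 * surrogate_gradient_function psi_fn = select_surrogate_gradient( "piecewise_linear" );
 * const double psi = ( this->*psi_fn )( r, v_m, v_th, beta, gamma );
 * @endcode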
*
* @param surrogate_gradient_function_name The name of the surrogate gradient function.
* @return The selected surrogate gradient function.
*/
surrogate_gradient_function select_surrogate_gradient( const std::string& surrogate_gradient_function_name );

/**
* Compute the surrogate gradient with a piecewise linear function around the spike time (used, e.g., in Bellec
* et al., 2020).
* @brief Computes the surrogate gradient with a piecewise linear function around the spike time.
*
* The piecewise linear surrogate function is used, for example, in Bellec et al. (2020).
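 *
 * For orientation, this pseudo-derivative commonly takes the general form
 * @f$ \psi = \gamma \max( 0, 1 - \beta | v_\text{m} - v_\text{th} | ) @f$ outside the refractory period
 * and is zero while refractory steps remain; the exact argument scaling is defined in the implementation.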
*
* @param r The number of remaining refractory steps.
* @param v_m The membrane voltage.
* @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage.
* @param beta The width scaling of the surrogate gradient function.
* @param gamma The height scaling of the surrogate gradient function.
* @return The surrogate gradient of the membrane voltage.
*/
double compute_piecewise_linear_surrogate_gradient( const double r,
const double v_m,
@@ -146,8 +190,18 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec
const double gamma );

/**
* Compute the surrogate gradient with an exponentially decaying function around the spike time (used, e.g., in
* Shrestha and Orchard, 2018).
* @brief Computes the surrogate gradient with an exponentially decaying function around the spike time.
*
* The exponential surrogate function is used, for example, in Shrestha and Orchard (2018).
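 *
 * As a guide, this pseudo-derivative commonly takes the general form
 * @f$ \psi = \gamma e^{ -\beta | v_\text{m} - v_\text{th} | } @f$ outside the refractory period and is
 * zero while refractory steps remain; the exact argument scaling is defined in the implementation.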
*
* @param r The number of remaining refractory steps.
* @param v_m The membrane voltage.
 * @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage.
* @param beta The width scaling of the surrogate gradient function.
* @param gamma The height scaling of the surrogate gradient function.
*
* @return The surrogate gradient of the membrane voltage.
*/
double compute_exponential_surrogate_gradient( const double r,
const double v_m,
@@ -156,62 +210,136 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec
const double gamma );

/**
* Compute the surrogate gradient with a function corresponding to the derivative of a fast sigmoid around the spike
* (used, e.g., in Zenke and Ganguli, 2018).
* @brief Computes the surrogate gradient with a function reflecting the derivative of a fast sigmoid around the spike
* time.
*
 * The fast sigmoid derivative surrogate function is used, for example, in Zenke and Ganguli (2018).
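 *
 * For orientation, this pseudo-derivative commonly takes the general form
 * @f$ \psi = \gamma ( 1 + \beta | v_\text{m} - v_\text{th} | )^{-2} @f$ outside the refractory period
 * and is zero while refractory steps remain; the exact argument scaling is defined in the implementation.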
*
* @param r The number of remaining refractory steps.
* @param v_m The membrane voltage.
* @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage.
* @param beta The width scaling of the surrogate gradient function.
* @param gamma The height scaling of the surrogate gradient function.
*
* @return The surrogate gradient of the membrane voltage.
*/
double compute_fast_sigmoid_derivative_surrogate_gradient( const double r,
const double v_m,
const double v_th,
const double beta,
const double gamma );

//! Compute the surrogate gradient with an arctan function around the spike time (used, e.g., in Fang et al., 2021).
/**
* @brief Computes the surrogate gradient with an inverse tangent function around the spike time.
*
* The inverse tangent surrogate gradient function is used, for example, in Fang et al. (2021).
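 *
 * For orientation, this pseudo-derivative commonly takes the general form
 * @f$ \psi = \gamma / ( 1 + ( \beta ( v_\text{m} - v_\text{th} ) )^2 ) @f$ outside the refractory period
 * and is zero while refractory steps remain; the exact argument scaling is defined in the implementation.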
*
* @param r The number of remaining refractory steps.
* @param v_m The membrane voltage.
* @param v_th The spike threshold voltage. For adaptive neurons, the adaptive spike threshold voltage.
* @param beta The width scaling of the surrogate gradient function.
* @param gamma The height scaling of the surrogate gradient function.
*
* @return The surrogate gradient of the membrane voltage.
*/
double compute_arctan_surrogate_gradient( const double r,
const double v_m,
const double v_th,
const double beta,
const double gamma );

//! Create an entry for the given time step at the end of the eprop history.
/**
* Creates an entry for the specified time step at the end of the eprop history.
*
* @param time_step The time step.
*/
void append_new_eprop_history_entry( const long time_step );

//! Write the given surrogate gradient value to the history at the given time step.
/**
* Writes the surrogate gradient to the eprop history entry at the specified time step.
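 *
 * A per-time-step archiving sketch for a recurrent model (`t` and `psi` are illustrative):
 * @code
 * append_new_eprop_history_entry( t );
 * write_surrogate_gradient_to_history( t, psi ); // psi from one of the compute_*_surrogate_gradient functions
 * @endcode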
*
* @param time_step The time step.
* @param surrogate_gradient The surrogate gradient.
*/
void write_surrogate_gradient_to_history( const long time_step, const double surrogate_gradient );

/**
* Update the learning signal in the eprop history entry of the given time step by writing the value of the incoming
* learning signal to the history or adding it to the existing value in case of multiple readout neurons.
 * @brief Writes the learning signal to the eprop history entry at the specified time step.
*
* Updates the learning signal in the eprop history entry of the specified time step by writing the value of the
* incoming learning signal to the history or adding it to the existing value in case of multiple readout
* neurons.
*
* @param time_step The time step.
* @param learning_signal The learning signal.
* @param has_norm_step Flag indicating if an extra time step is used for communication between readout
* neurons to normalize the readout signal outputs, as for softmax.
*/
void write_learning_signal_to_history( const long time_step,
const double learning_signal,
const bool has_norm_step = true );

//! Create an entry in the firing rate regularization history for the current update.
/**
* Calculates the firing rate regularization for the current update and writes it to a new entry in the firing rate
* regularization history.
*
* @param t_current_update The current update time.
* @param f_target The target firing rate.
* @param c_reg The firing rate regularization coefficient.
*/
void write_firing_rate_reg_to_history( const long t_current_update, const double f_target, const double c_reg );

//! Calculate the current firing rate regularization and add the value to the learning signal.
void write_firing_rate_reg_to_history( const long t,
/**
* Calculates the current firing rate regularization and writes it to the eprop history at the specified time step.
*
* @param time_step The time step.
* @param z The spike state variable.
* @param f_target The target firing rate.
 * @param kappa_reg The low-pass filter coefficient of the firing rate regularization.
* @param c_reg The firing rate regularization coefficient.
*/
void write_firing_rate_reg_to_history( const long time_step,
const double z,
const double f_target,
const double kappa_reg,
const double c_reg );

//! Get an iterator pointing to the firing rate regularization history of the given time step.
/**
* Retrieves the firing rate regularization at the specified time step from the firing rate regularization history.
*
* @param time_step The time step.
*
* @return The firing rate regularization at the specified time step.
*/
double get_firing_rate_reg_history( const long time_step );

//! Return learning signal from history for given time step or zero if time step not in history
/**
* Retrieves the learning signal from the eprop history at the specified time step.
*
* @param time_step The time step.
* @param has_norm_step Flag indicating if an extra time step is used for communication between readout neurons to
* normalize the readout signal outputs, as for softmax.
*
* @return The learning signal at the specified time step or zero if time step is not in the history.
*/
double get_learning_signal_from_history( const long time_step, const bool has_norm_step = true );

/**
* Erase parts of the firing rate regularization history for which the access counter in the update history has
 * @brief Erases the used firing rate regularization history.
*
* Erases parts of the firing rate regularization history for which the access counter in the update history has
* decreased to zero since no synapse needs them any longer.
*/
void erase_used_firing_rate_reg_history();

//! Count emitted spike for the firing rate regularization.
/**
* Counts an emitted spike for the firing rate regularization.
*/
void count_spike();

//! Reset spike count for the firing rate regularization.
/**
* Resets the spike count for the firing rate regularization.
*/
void reset_spike_count();

//! Firing rate regularization.
@@ -228,7 +356,7 @@ class EpropArchivingNodeRecurrent : public EpropArchivingNode< HistEntryEpropRec
std::vector< HistEntryEpropFiringRateReg > firing_rate_reg_history_;

/**
* Map names of surrogate gradients provided to corresponding pointers to member functions.
* Maps provided names of surrogate gradients to corresponding pointers to member functions.
*
* @todo In the long run, this map should be handled by a manager with proper registration functions,
* so that external modules can add their own gradient functions.
@@ -254,16 +382,35 @@ EpropArchivingNodeRecurrent::reset_spike_count()
class EpropArchivingNodeReadout : public EpropArchivingNode< HistEntryEpropReadout >
{
public:
//! Default constructor.
/**
* Constructs a new EpropArchivingNodeReadout object.
*/
EpropArchivingNodeReadout();

//! Copy constructor.
EpropArchivingNodeReadout( const EpropArchivingNodeReadout& );
/**
* Constructs a new EpropArchivingNodeReadout object by copying another EpropArchivingNodeReadout object.
*
* @param other The EpropArchivingNodeReadout object to copy.
*/
EpropArchivingNodeReadout( const EpropArchivingNodeReadout& other );

//! Create an entry for the given time step at the end of the eprop history.
/**
* Creates an entry for the specified time step at the end of the eprop history.
*
* @param time_step The time step.
* @param has_norm_step Flag indicating if an extra time step is used for communication between readout neurons to
* normalize the readout signal outputs, as for softmax.
*/
void append_new_eprop_history_entry( const long time_step, const bool has_norm_step = true );

//! Write the given error signal value to history at the given time step.
/**
* Writes the error signal to the eprop history at the specified time step.
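 *
 * A per-time-step archiving sketch for a readout model (`t` and the signal variables are illustrative):
 * @code
 * append_new_eprop_history_entry( t );
 * write_error_signal_to_history( t, readout_signal - target_signal );
 * @endcode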
*
* @param time_step The time step.
* @param error_signal The error signal.
* @param has_norm_step Flag indicating if an extra time step is used for communication between readout neurons to
* normalize the readout signal outputs, as for softmax.
*/
void
write_error_signal_to_history( const long time_step, const double error_signal, const bool has_norm_step = true );
};