Skip to content

Commit

Permalink
Add adaptation mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
akorgor committed Sep 16, 2024
1 parent 5fd2294 commit 7689f36
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 9 deletions.
53 changes: 48 additions & 5 deletions models/eprop_iaf_psc_delta_adapt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ void
RecordablesMap< eprop_iaf_psc_delta_adapt >::create()
{
// use standard names wherever you can for consistency!
insert_( names::adaptation, &eprop_iaf_psc_delta_adapt::get_adaptation_ );
insert_( names::V_th_adapt, &eprop_iaf_psc_delta_adapt::get_v_th_adapt_ );
insert_( names::learning_signal, &eprop_iaf_psc_delta_adapt::get_learning_signal_ );
insert_( names::surrogate_gradient, &eprop_iaf_psc_delta_adapt::get_surrogate_gradient_ );
insert_( names::V_m, &eprop_iaf_psc_delta_adapt::get_V_m_ );
Expand All @@ -72,7 +74,9 @@ RecordablesMap< eprop_iaf_psc_delta_adapt >::create()
* ---------------------------------------------------------------- */

nest::eprop_iaf_psc_delta_adapt::Parameters_::Parameters_()
: tau_m_( 10.0 ) // ms
: adapt_beta_( 1.0 )
, adapt_tau_( 10.0 )
, tau_m_( 10.0 ) // ms
, c_m_( 250.0 ) // pF
, t_ref_( 2.0 ) // ms
, E_L_( -70.0 ) // mV
Expand All @@ -93,7 +97,9 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::Parameters_()
}

nest::eprop_iaf_psc_delta_adapt::State_::State_()
: y0_( 0.0 )
: adapt_( 0.0 )
, v_th_adapt_( 15.0 )
, y0_( 0.0 )
, y3_( 0.0 )
, r_( 0 )
, refr_spikes_buffer_( 0.0 )
Expand All @@ -110,6 +116,8 @@ nest::eprop_iaf_psc_delta_adapt::State_::State_()
void
nest::eprop_iaf_psc_delta_adapt::Parameters_::get( DictionaryDatum& d ) const
{
def< double >( d, names::adapt_beta, adapt_beta_ );
def< double >( d, names::adapt_tau, adapt_tau_ );
def< double >( d, names::E_L, E_L_ ); // Resting potential
def< double >( d, names::I_e, I_e_ );
def< double >( d, names::V_th, V_th_ + E_L_ ); // threshold value
Expand Down Expand Up @@ -165,6 +173,8 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Nod
V_min_ -= delta_EL;
}

updateValueParam< double >( d, names::adapt_beta, adapt_beta_, node );
updateValueParam< double >( d, names::adapt_tau, adapt_tau_, node );
updateValueParam< double >( d, names::I_e, I_e_, node );
updateValueParam< double >( d, names::C_m, c_m_, node );
updateValueParam< double >( d, names::tau_m, tau_m_, node );
Expand All @@ -182,6 +192,16 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Nod
updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node );
updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node );

if ( adapt_beta_ < 0 )
{
throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." );
}

if ( adapt_tau_ <= 0 )
{
throw BadProperty( "Threshold adaptation time constant adapt_tau > 0 required." );
}

if ( V_reset_ >= V_th_ )
{
throw BadProperty( "Reset potential must be smaller than threshold." );
Expand Down Expand Up @@ -231,6 +251,8 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Nod
void
nest::eprop_iaf_psc_delta_adapt::State_::get( DictionaryDatum& d, const Parameters_& p ) const
{
def< double >( d, names::adaptation, adapt_ );
def< double >( d, names::V_th_adapt, v_th_adapt_ + p.E_L_ );
def< double >( d, names::V_m, y3_ + p.E_L_ ); // Membrane potential
def< double >( d, names::surrogate_gradient, surrogate_gradient_ );
def< double >( d, names::learning_signal, learning_signal_ );
Expand All @@ -250,6 +272,18 @@ nest::eprop_iaf_psc_delta_adapt::State_::set( const DictionaryDatum& d,
{
y3_ -= delta_EL;
}

// adaptive threshold can only be set indirectly via the adaptation variable
if ( updateValueParam< double >( d, names::adaptation, adapt_, node ) )
{
// if E_L changed in this SetStatus call, p.V_th_ has been adjusted and no further action is needed
v_th_adapt_ = p.V_th_ + p.adapt_beta_ * adapt_;
}
else
{
// adjust voltage to change in E_L
v_th_adapt_ -= delta_EL;
}
}

nest::eprop_iaf_psc_delta_adapt::Buffers_::Buffers_( eprop_iaf_psc_delta_adapt& n )
Expand Down Expand Up @@ -310,6 +344,7 @@ nest::eprop_iaf_psc_delta_adapt::pre_run_hook()
V_.P30_ = 1 / P_.c_m_ * ( 1 - V_.P33_ ) * P_.tau_m_;

V_.P_z_in_ = 1.0;
V_.P_adapt_ = std::exp( -h / P_.adapt_tau_ );

// t_ref_ specifies the length of the absolute refractory period as
// a double in ms. The grid based iaf_psp_delta can only handle refractory
Expand Down Expand Up @@ -389,12 +424,18 @@ nest::eprop_iaf_psc_delta_adapt::update( Time const& origin, const long from, co
--S_.r_;
}

S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient_ )( S_.r_, S_.y3_, P_.V_th_, P_.beta_, P_.gamma_ );

double z = 0.0; // spike state variable

S_.adapt_ = V_.P_adapt_ * S_.adapt_ + z;
S_.v_th_adapt_ = P_.V_th_ + P_.adapt_beta_ * S_.adapt_;

S_.surrogate_gradient_ =
( this->*compute_surrogate_gradient_ )( S_.r_, S_.y3_, S_.v_th_adapt_, P_.beta_, P_.gamma_ );


// threshold crossing
if ( S_.y3_ >= P_.V_th_ )
if ( S_.y3_ >= S_.v_th_adapt_ )
{
S_.r_ = V_.RefractoryCounts_;
S_.y3_ = P_.V_reset_;
Expand Down Expand Up @@ -502,7 +543,8 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike,
firing_rate_reg = eprop_hist_it->firing_rate_reg_;

z_bar = V_.P33_ * z_bar + V_.P_z_in_ * z;
e = psi * z_bar;
e = psi * ( z_bar - P_.adapt_beta_ * epsilon );
epsilon = V_.P_adapt_ * epsilon + e;
e_bar = P_.kappa_ * e_bar + ( 1.0 - P_.kappa_ ) * e;
e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e;

Expand All @@ -529,6 +571,7 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike,
z_bar *= std::pow( V_.P33_, cutoff_to_spike_interval );
e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval );
e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval );
epsilon *= std::pow( V_.P_adapt_, cutoff_to_spike_interval );
}
}

Expand Down
29 changes: 29 additions & 0 deletions models/eprop_iaf_psc_delta_adapt.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent
*/
struct Parameters_
{
//! Prefactor of the threshold adaptation.
double adapt_beta_;

//! Time constant of the threshold adaptation (ms).
double adapt_tau_;

/** Membrane time constant in ms. */
double tau_m_;

Expand Down Expand Up @@ -366,6 +372,12 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent
*/
struct State_
{
//! Adaptation variable.
double adapt_;

//! Adapting spike threshold voltage.
double v_th_adapt_;

double y0_;
//! This is the membrane potential RELATIVE TO RESTING POTENTIAL.
double y3_;
Expand Down Expand Up @@ -430,6 +442,9 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent
//! Propagator matrix entry for evolving the incoming spike variables.
double P_z_in_;

//! Propagator matrix entry for evolving the adaptation (mathematical symbol "rho" in user documentation).
double P_adapt_;

int RefractoryCounts_;

//! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes.
Expand Down Expand Up @@ -460,6 +475,20 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent
return S_.learning_signal_;
}

//! Get the current value of the adapting threshold.
double
get_v_th_adapt_() const
{
return S_.v_th_adapt_ + P_.E_L_;
}

//! Get the current value of the adaptation.
double
get_adaptation_() const
{
return S_.adapt_;
}

// ----------------------------------------------------------------

/**
Expand Down
8 changes: 4 additions & 4 deletions testsuite/pytests/test_eprop_plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def test_unsupported_model_raises(target_model):
"gradient_descent",
[
0.32286231964124,
0.61322219696014,
0.63745062813969,
0.63844466107304,
0.58671835471489,
0.53799160861627,
0.56775204949340,
0.52965220810013,
0.61350867996684,
],
),
],
Expand Down

0 comments on commit 7689f36

Please sign in to comment.