From 7689f366c0d18935e2211ca279211ed3a283880b Mon Sep 17 00:00:00 2001 From: Agnes Korcsak-Gorzo <40828647+akorgor@users.noreply.github.com> Date: Sat, 14 Sep 2024 12:27:12 +0200 Subject: [PATCH] Add adaptation mechanism --- models/eprop_iaf_psc_delta_adapt.cpp | 53 ++++++++++++++++++++-- models/eprop_iaf_psc_delta_adapt.h | 29 ++++++++++++ testsuite/pytests/test_eprop_plasticity.py | 8 ++-- 3 files changed, 81 insertions(+), 9 deletions(-) diff --git a/models/eprop_iaf_psc_delta_adapt.cpp b/models/eprop_iaf_psc_delta_adapt.cpp index c848bd9c3f..d721c7a6a5 100644 --- a/models/eprop_iaf_psc_delta_adapt.cpp +++ b/models/eprop_iaf_psc_delta_adapt.cpp @@ -62,6 +62,8 @@ void RecordablesMap< eprop_iaf_psc_delta_adapt >::create() { // use standard names wherever you can for consistency! + insert_( names::adaptation, &eprop_iaf_psc_delta_adapt::get_adaptation_ ); + insert_( names::V_th_adapt, &eprop_iaf_psc_delta_adapt::get_v_th_adapt_ ); insert_( names::learning_signal, &eprop_iaf_psc_delta_adapt::get_learning_signal_ ); insert_( names::surrogate_gradient, &eprop_iaf_psc_delta_adapt::get_surrogate_gradient_ ); insert_( names::V_m, &eprop_iaf_psc_delta_adapt::get_V_m_ ); @@ -72,7 +74,9 @@ RecordablesMap< eprop_iaf_psc_delta_adapt >::create() * ---------------------------------------------------------------- */ nest::eprop_iaf_psc_delta_adapt::Parameters_::Parameters_() - : tau_m_( 10.0 ) // ms + : adapt_beta_( 1.0 ) + , adapt_tau_( 10.0 ) + , tau_m_( 10.0 ) // ms , c_m_( 250.0 ) // pF , t_ref_( 2.0 ) // ms , E_L_( -70.0 ) // mV @@ -93,7 +97,9 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::Parameters_() } nest::eprop_iaf_psc_delta_adapt::State_::State_() - : y0_( 0.0 ) + : adapt_( 0.0 ) + , v_th_adapt_( 15.0 ) + , y0_( 0.0 ) , y3_( 0.0 ) , r_( 0 ) , refr_spikes_buffer_( 0.0 ) @@ -110,6 +116,8 @@ nest::eprop_iaf_psc_delta_adapt::State_::State_() void nest::eprop_iaf_psc_delta_adapt::Parameters_::get( DictionaryDatum& d ) const { + def< double >( d, names::adapt_beta, adapt_beta_ ); + def< double >( d, names::adapt_tau, adapt_tau_ ); def< double >( d, names::E_L, E_L_ ); // Resting potential def< double >( d, names::I_e, I_e_ ); def< double >( d, names::V_th, V_th_ + E_L_ ); // threshold value @@ -165,6 +173,8 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Nod V_min_ -= delta_EL; } + updateValueParam< double >( d, names::adapt_beta, adapt_beta_, node ); + updateValueParam< double >( d, names::adapt_tau, adapt_tau_, node ); updateValueParam< double >( d, names::I_e, I_e_, node ); updateValueParam< double >( d, names::C_m, c_m_, node ); updateValueParam< double >( d, names::tau_m, tau_m_, node ); @@ -182,6 +192,16 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Nod updateValueParam< double >( d, names::kappa_reg, kappa_reg_, node ); updateValueParam< double >( d, names::eprop_isi_trace_cutoff, eprop_isi_trace_cutoff_, node ); + if ( adapt_beta_ < 0 ) + { + throw BadProperty( "Threshold adaptation prefactor adapt_beta ≥ 0 required." ); + } + + if ( adapt_tau_ <= 0 ) + { + throw BadProperty( "Threshold adaptation time constant adapt_tau > 0 required." ); + } + if ( V_reset_ >= V_th_ ) { throw BadProperty( "Reset potential must be smaller than threshold." ); @@ -231,6 +251,8 @@ nest::eprop_iaf_psc_delta_adapt::Parameters_::set( const DictionaryDatum& d, Nod void nest::eprop_iaf_psc_delta_adapt::State_::get( DictionaryDatum& d, const Parameters_& p ) const { + def< double >( d, names::adaptation, adapt_ ); + def< double >( d, names::V_th_adapt, v_th_adapt_ + p.E_L_ ); def< double >( d, names::V_m, y3_ + p.E_L_ ); // Membrane potential def< double >( d, names::surrogate_gradient, surrogate_gradient_ ); def< double >( d, names::learning_signal, learning_signal_ ); @@ -250,6 +272,18 @@ nest::eprop_iaf_psc_delta_adapt::State_::set( const DictionaryDatum& d, { y3_ -= delta_EL; } + + // adaptive threshold can only be set indirectly via the adaptation variable + if ( updateValueParam< double >( d, names::adaptation, adapt_, node ) ) + { + // if E_L changed in this SetStatus call, p.V_th_ has been adjusted and no further action is needed + v_th_adapt_ = p.V_th_ + p.adapt_beta_ * adapt_; + } + else + { + // adjust voltage to change in E_L + v_th_adapt_ -= delta_EL; + } } nest::eprop_iaf_psc_delta_adapt::Buffers_::Buffers_( eprop_iaf_psc_delta_adapt& n ) @@ -310,6 +344,7 @@ nest::eprop_iaf_psc_delta_adapt::pre_run_hook() V_.P30_ = 1 / P_.c_m_ * ( 1 - V_.P33_ ) * P_.tau_m_; V_.P_z_in_ = 1.0; + V_.P_adapt_ = std::exp( -h / P_.adapt_tau_ ); // t_ref_ specifies the length of the absolute refractory period as // a double in ms. The grid based iaf_psp_delta can only handle refractory @@ -389,12 +424,18 @@ nest::eprop_iaf_psc_delta_adapt::update( Time const& origin, const long from, co --S_.r_; } - S_.surrogate_gradient_ = ( this->*compute_surrogate_gradient_ )( S_.r_, S_.y3_, P_.V_th_, P_.beta_, P_.gamma_ ); double z = 0.0; // spike state variable + S_.adapt_ = V_.P_adapt_ * S_.adapt_ + z; + S_.v_th_adapt_ = P_.V_th_ + P_.adapt_beta_ * S_.adapt_; + + S_.surrogate_gradient_ = + ( this->*compute_surrogate_gradient_ )( S_.r_, S_.y3_, S_.v_th_adapt_, P_.beta_, P_.gamma_ ); + + // threshold crossing - if ( S_.y3_ >= P_.V_th_ ) + if ( S_.y3_ >= S_.v_th_adapt_ ) { S_.r_ = V_.RefractoryCounts_; S_.y3_ = P_.V_reset_; @@ -502,7 +543,8 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike, firing_rate_reg = eprop_hist_it->firing_rate_reg_; z_bar = V_.P33_ * z_bar + V_.P_z_in_ * z; - e = psi * z_bar; + e = psi * ( z_bar - P_.adapt_beta_ * epsilon ); + epsilon = V_.P_adapt_ * epsilon + e; e_bar = P_.kappa_ * e_bar + ( 1.0 - P_.kappa_ ) * e; e_bar_reg = P_.kappa_reg_ * e_bar_reg + ( 1.0 - P_.kappa_reg_ ) * e; @@ -529,6 +571,7 @@ eprop_iaf_psc_delta_adapt::compute_gradient( const long t_spike, z_bar *= std::pow( V_.P33_, cutoff_to_spike_interval ); e_bar *= std::pow( P_.kappa_, cutoff_to_spike_interval ); e_bar_reg *= std::pow( P_.kappa_reg_, cutoff_to_spike_interval ); + epsilon *= std::pow( V_.P_adapt_, cutoff_to_spike_interval ); } } diff --git a/models/eprop_iaf_psc_delta_adapt.h b/models/eprop_iaf_psc_delta_adapt.h index 4ff9f68b00..36e39d6fd7 100644 --- a/models/eprop_iaf_psc_delta_adapt.h +++ b/models/eprop_iaf_psc_delta_adapt.h @@ -295,6 +295,12 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent */ struct Parameters_ { + //! Prefactor of the threshold adaptation. + double adapt_beta_; + + //! Time constant of the threshold adaptation (ms). + double adapt_tau_; + /** Membrane time constant in ms. */ double tau_m_; @@ -366,6 +372,12 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent */ struct State_ { + //! Adaptation variable. + double adapt_; + + //! Adapting spike threshold voltage. + double v_th_adapt_; + double y0_; //! This is the membrane potential RELATIVE TO RESTING POTENTIAL. double y3_; @@ -430,6 +442,9 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent //! Propagator matrix entry for evolving the incoming spike variables. double P_z_in_; + //! Propagator matrix entry for evolving the adaptation (mathematical symbol "rho" in user documentation). + double P_adapt_; + int RefractoryCounts_; //! Time steps from the previous spike until the cutoff of e-prop update integration between two spikes. @@ -460,6 +475,20 @@ class eprop_iaf_psc_delta_adapt : public EpropArchivingNodeRecurrent return S_.learning_signal_; } + //! Get the current value of the adapting threshold. + double + get_v_th_adapt_() const + { + return S_.v_th_adapt_ + P_.E_L_; + } + + //! Get the current value of the adaptation. + double + get_adaptation_() const + { + return S_.adapt_; + } + // ---------------------------------------------------------------- /** diff --git a/testsuite/pytests/test_eprop_plasticity.py b/testsuite/pytests/test_eprop_plasticity.py index 94d2e63753..bd2103fc08 100644 --- a/testsuite/pytests/test_eprop_plasticity.py +++ b/testsuite/pytests/test_eprop_plasticity.py @@ -105,10 +105,10 @@ def test_unsupported_model_raises(target_model): "gradient_descent", [ 0.32286231964124, - 0.61322219696014, - 0.63745062813969, - 0.63844466107304, - 0.58671835471489, + 0.53799160861627, + 0.56775204949340, + 0.52965220810013, + 0.61350867996684, ], ), ],