Implement early stopping #28

Merged 6 commits on Sep 25, 2024
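
This PR makes the learning rate changeable during a simulation, which is the hook needed for early stopping (e.g., setting eta to 0.0 once validation error stops improving). The common properties now record the first user-set learning rate (eta_first_) and count how often eta changes (n_eta_change_); each WeightOptimizer instance gains a synapse-local eta_ and a counter of completed optimizations (n_optimize_). The net effect is that a gradient batch is always applied with the learning rate under which it was accumulated: a synapse that has never optimized applies its first batch with eta_first_, and every synapse resyncs to the shared cp.eta_ only at the end of an optimization.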
28 changes: 25 additions & 3 deletions models/weight_optimizer.cpp
@@ -34,6 +34,8 @@ namespace nest
WeightOptimizerCommonProperties::WeightOptimizerCommonProperties()
: batch_size_( 1 )
, eta_( 1e-4 )
, eta_first_( 1e-4 )
, n_eta_change_( 0 )
, Wmin_( -100.0 )
, Wmax_( 100.0 )
, optimize_each_step_( true )
@@ -43,6 +45,8 @@ WeightOptimizerCommonProperties::WeightOptimizerCommonProperties()
WeightOptimizerCommonProperties::WeightOptimizerCommonProperties( const WeightOptimizerCommonProperties& cp )
: batch_size_( cp.batch_size_ )
, eta_( cp.eta_ )
, eta_first_( cp.eta_first_ )
, n_eta_change_( cp.n_eta_change_ )
, Wmin_( cp.Wmin_ )
, Wmax_( cp.Wmax_ )
, optimize_each_step_( cp.optimize_each_step_ )
@@ -77,6 +81,16 @@ WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d )
{
throw BadProperty( "Learning rate eta ≥ 0 required." );
}

if ( new_eta != eta_ )
{
if ( n_eta_change_ == 0 )
{
eta_first_ = new_eta;
}
n_eta_change_ += 1;
}

eta_ = new_eta;

double new_Wmin = Wmin_;
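
For reference, a minimal standalone sketch of the bookkeeping this hunk adds (plain C++, not NEST code; the struct and names below are invented for illustration). Only a value that actually differs from the current eta counts as a change, and the value of the first change is remembered:

```cpp
#include <cassert>

// Invented stand-in for WeightOptimizerCommonProperties, reduced to
// the eta bookkeeping introduced in this PR.
struct CommonProps
{
  double eta_ = 1e-4;       // current shared learning rate
  double eta_first_ = 1e-4; // value of the first change to eta
  long n_eta_change_ = 0;   // how often eta has changed

  void set_eta( double new_eta )
  {
    if ( new_eta != eta_ )
    {
      if ( n_eta_change_ == 0 )
      {
        eta_first_ = new_eta; // remember the first user-set rate
      }
      n_eta_change_ += 1;
    }
    eta_ = new_eta;
  }
};

int main()
{
  CommonProps cp;
  cp.set_eta( 1e-4 ); // same as the default: not counted as a change
  assert( cp.n_eta_change_ == 0 );
  cp.set_eta( 5e-3 ); // first real change: recorded in eta_first_
  cp.set_eta( 0.0 );  // e.g. early stopping: eta_first_ keeps 5e-3
  assert( cp.n_eta_change_ == 2 and cp.eta_first_ == 5e-3 );
}
```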
@@ -96,6 +110,8 @@ WeightOptimizerCommonProperties::set_status( const DictionaryDatum& d )
WeightOptimizer::WeightOptimizer()
: sum_gradients_( 0.0 )
, optimization_step_( 1 )
, eta_( 1e-4 )
, n_optimize_( 0 )
{
}

@@ -110,11 +126,15 @@ WeightOptimizer::set_status( const DictionaryDatum& d )
}

double
- WeightOptimizer::optimized_weight( const WeightOptimizerCommonProperties& cp,
+ WeightOptimizer::optimized_weight( WeightOptimizerCommonProperties& cp,
const size_t idx_current_update,
const double gradient,
double weight )
{
if ( cp.n_eta_change_ != 0 and n_optimize_ == 0 )
{
eta_ = cp.eta_first_;
}
sum_gradients_ += gradient;

if ( optimization_step_ == 0 )
@@ -127,6 +147,8 @@ WeightOptimizer::optimized_weight( const WeightOptimizerCommonProperties& cp,
{
sum_gradients_ /= cp.batch_size_;
weight = std::max( cp.Wmin_, std::min( optimize_( cp, weight, current_optimization_step ), cp.Wmax_ ) );
eta_ = cp.eta_;
n_optimize_ += 1;
optimization_step_ = current_optimization_step;
}
return weight;
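
To see why both counters matter, here is a self-contained sketch of the per-synapse path (again with invented names, the optimizer reduced to plain gradient descent, and the batch-boundary logic omitted). The first batch a synapse ever applies uses eta_first_; afterwards it follows the shared rate, so a batch in flight when eta changes is still applied with the old rate:

```cpp
#include <cassert>

struct Cp // stand-in for the common properties after early stopping kicked in
{
  double eta_ = 0.0;        // shared rate, already set to 0
  double eta_first_ = 5e-3; // first rate the user configured
  long n_eta_change_ = 2;
};

struct Syn // stand-in for WeightOptimizer
{
  double eta_ = 1e-4;
  long n_optimize_ = 0;
  double sum_gradients_ = 0.0;

  double apply_batch( const Cp& cp, double weight )
  {
    if ( cp.n_eta_change_ != 0 and n_optimize_ == 0 )
    {
      eta_ = cp.eta_first_; // pending first batch uses the original rate
    }
    weight -= eta_ * sum_gradients_; // plain gradient-descent step
    sum_gradients_ = 0.0;
    eta_ = cp.eta_; // from now on follow the shared rate (here: 0)
    n_optimize_ += 1;
    return weight;
  }
};

int main()
{
  Cp cp;
  Syn syn;
  syn.sum_gradients_ = 2.0;
  const double w = syn.apply_batch( cp, 1.0 );
  assert( w == 1.0 - 5e-3 * 2.0 );         // applied with eta_first_
  syn.sum_gradients_ = 2.0;
  assert( syn.apply_batch( cp, w ) == w ); // eta is now 0: weight frozen
}
```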
@@ -152,7 +174,7 @@ WeightOptimizerGradientDescent::WeightOptimizerGradientDescent()
double
WeightOptimizerGradientDescent::optimize_( const WeightOptimizerCommonProperties& cp, double weight, size_t )
{
- weight -= cp.eta_ * sum_gradients_;
+ weight -= eta_ * sum_gradients_;
sum_gradients_ = 0.0;
return weight;
}
@@ -251,7 +273,7 @@ WeightOptimizerAdam::optimize_( const WeightOptimizerCommonProperties& cp,
beta_1_power_ *= acp.beta_1_;
beta_2_power_ *= acp.beta_2_;

- const double alpha = cp.eta_ * std::sqrt( 1.0 - beta_2_power_ ) / ( 1.0 - beta_1_power_ );
+ const double alpha = eta_ * std::sqrt( 1.0 - beta_2_power_ ) / ( 1.0 - beta_1_power_ );

m_ = acp.beta_1_ * m_ + ( 1.0 - acp.beta_1_ ) * sum_gradients_;
v_ = acp.beta_2_ * v_ + ( 1.0 - acp.beta_2_ ) * sum_gradients_ * sum_gradients_;
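
As with the gradient-descent optimizer above, the only change to Adam is that the bias-corrected step size alpha is now computed from the synapse-local eta_ instead of the shared cp.eta_. The updates of the moment estimates m_ and v_ are untouched, and the weight update that follows this hunk is presumably unchanged as well; in the folded formulation this alpha implies, that step is w ← w − α · m / (√v + ε̂).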
16 changes: 14 additions & 2 deletions models/weight_optimizer.h
@@ -207,9 +207,15 @@ class WeightOptimizerCommonProperties
//! Size of an optimization batch.
size_t batch_size_;

- //! Learning rate.
+ //! Learning rate common to all synapses.
double eta_;

//! First learning rate that differs from the default.
double eta_first_;

//! Number of changes to the learning rate.
long n_eta_change_;

//! Minimal value for synaptic weight.
double Wmin_;

@@ -252,7 +258,7 @@ class WeightOptimizer
virtual void set_status( const DictionaryDatum& d );

//! Return optimized weight based on current weight.
- double optimized_weight( const WeightOptimizerCommonProperties& cp,
+ double optimized_weight( WeightOptimizerCommonProperties& cp,
const size_t idx_current_update,
const double gradient,
double weight );
@@ -266,6 +272,12 @@

//! Current optimization step, whereby optimization happens every batch_size_ steps.
size_t optimization_step_;

//! Learning rate private to the synapse.
double eta_;

//! Number of optimizations.
long n_optimize_;
};

/**