Implement early stopping #28
Conversation
Pull request automatically marked stale!
When the |
n_iter_sim = 0
nest.Simulate(duration["total_offset"])

phase_label_previous = ""
for k_iter in range(n_iter):
    if do_early_stopping and k_iter % n_validate_every == 0:
        error_val, n_iter_sim, phase_label_previous = run("validation", n_iter_sim, eta_test, phase_label_previous)

        if k_iter > 0 and error_val < stop_crit:
            errors_early_stop, n_iter_sim, phase_label_previous = run(
                "early-stopping", n_iter_sim, eta_test, phase_label_previous
            )
            if np.mean(errors_early_stop) < stop_crit:
                break

    run_iter = min(n_iter - k_iter, n_validate_every)
    _, n_iter_sim, phase_label_previous = run("training", n_iter_sim, eta_train, phase_label_previous)

for _ in range(n_test):
    _, n_iter_sim, phase_label_previous = run("test", n_iter_sim, eta_test, phase_label_previous)

nest.Simulate(steps["extension_sim"])
Could the loop structure be reorganized to make the phase separation clearer? Currently, the only indication of which phase is running is the string labels, which are easy to miss. Would a reorganization like the following be more effective?
def simulate_training_phase(n_iter_sim, eta_train, phase_label_previous):
    run_iter = min(n_iter - k_iter, n_validate_every)
    return run("training", n_iter_sim, eta_train, phase_label_previous)

def simulate_validation_phase(n_iter_sim, eta_test, phase_label_previous):
    return run("validation", n_iter_sim, eta_test, phase_label_previous)

def simulate_early_stopping_phase(n_iter_sim, eta_test, phase_label_previous):
    errors_early_stop, n_iter_sim, phase_label_previous = run("early-stopping", n_iter_sim, eta_test, phase_label_previous)
    if np.mean(errors_early_stop) < stop_crit:
        return True, n_iter_sim, phase_label_previous
    return False, n_iter_sim, phase_label_previous

def simulate_test_phase(n_test, n_iter_sim, eta_test, phase_label_previous):
    for _ in range(n_test):
        run("test", n_iter_sim, eta_test, phase_label_previous)

n_iter_sim = 0
nest.Simulate(duration["total_offset"])
phase_label_previous = ""
for k_iter in range(n_iter):
    # Validation phase and early stopping check
    if do_early_stopping and k_iter % n_validate_every == 0:
        error_val, n_iter_sim, phase_label_previous = simulate_validation_phase(n_iter_sim, eta_test, phase_label_previous)
        if k_iter > 0 and error_val < stop_crit:
            should_stop, n_iter_sim, phase_label_previous = simulate_early_stopping_phase(n_iter_sim, eta_test, phase_label_previous)
            if should_stop:
                print(f"Early stopping at iteration {k_iter}")
                break
    # Training phase
    _, n_iter_sim, phase_label_previous = simulate_training_phase(n_iter_sim, eta_train, phase_label_previous)
# Test phase
simulate_test_phase(n_test, n_iter_sim, eta_test, phase_label_previous)
Nice idea! Please see 4c13272.
It looks great. The compartmentalization makes the code much cleaner and more organized, and it should make maintenance and modification easier. In particular, I like the new run_early_stopping function, which isolates the core early-stopping logic into a few easy-to-follow lines.
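For readers without access to the commit, here is a hypothetical sketch of what such an isolated run_early_stopping helper might look like; the signature, the fake_run stand-in, and the stop_crit value are assumptions for illustration, not the actual code from 4c13272:

```python
import numpy as np

def run_early_stopping(run, n_iter_sim, eta_test, phase_label_previous, stop_crit):
    """Hypothetical sketch: run the early-stopping phase once and decide
    whether training should halt, keeping the criterion in one place."""
    errors, n_iter_sim, phase_label_previous = run(
        "early-stopping", n_iter_sim, eta_test, phase_label_previous
    )
    should_stop = float(np.mean(errors)) < stop_crit
    return should_stop, n_iter_sim, phase_label_previous

# Toy stand-in for the simulation's run() so the helper can be exercised.
def fake_run(label, n_iter_sim, eta, prev):
    return [0.01, 0.02], n_iter_sim + 1, label

print(run_early_stopping(fake_run, 0, 0.1, "", stop_crit=0.05))
# → (True, 1, 'early-stopping')
```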
Co-authored-by: JesusEV <43375826+JesusEV@users.noreply.github.com>
This PR replaces PR #2 and implements the early-stopping algorithm as described in the corresponding evidence accumulation task implemented in TensorFlow. The only difference is that here the early-stopping criterion is evaluated after each validation step (e.g., every ten iterations) rather than after every iteration as in the TensorFlow implementation. Since the criterion is always assessed with the newest validation value, evaluating it for ten consecutive iterations against the same validation result, as the TensorFlow implementation does, seems wasteful.
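As a minimal, self-contained illustration of this schedule, the sketch below checks the stopping criterion only at validation steps; n_iter, n_validate_every, stop_crit, and the validate() helper are made-up stand-ins, not the PR's actual parameters or API:

```python
# Hypothetical stand-ins for the PR's parameters.
n_iter = 50              # total training iterations
n_validate_every = 10    # validation interval
stop_crit = 0.05         # early-stopping threshold

def validate(k):
    """Toy validation error that decays with training progress."""
    return 1.0 / (k + 1)

stopped_at = None
for k_iter in range(n_iter):
    # The criterion is checked only after each validation step,
    # not after every iteration as in the TensorFlow reference.
    if k_iter % n_validate_every == 0:
        error_val = validate(k_iter)
        if k_iter > 0 and error_val < stop_crit:
            stopped_at = k_iter
            break

print(stopped_at)
# → 20  (first validation step whose error 1/21 falls below 0.05)
```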
Currently, the NEST losses match the TF losses only for the iterations in which no weight update has occurred yet, i.e., the first two iterations.
To test whether these deviations stem from an extra spike caused by numerical differences between TensorFlow and NEST, a recurrent neuron (index 82) was forced to emit an extra spike at t = 4000, i.e., during the second iteration. This perturbation changes the loss only in its 11th decimal digit. The deviation between TF and NEST is therefore more likely due to the gradients not being computed correctly.
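To make "deviation in the 11th decimal digit" concrete, here is a sketch of how such a deviation could be quantified; the loss values and the first_deviating_decimal helper are made up for illustration, not the experiment's actual data:

```python
import math

def first_deviating_decimal(a, b):
    """Return the decimal position at which two values first differ,
    e.g. 11 for an absolute difference of about 1e-11."""
    diff = abs(a - b)
    if diff == 0.0:
        return None  # identical to machine precision
    return max(1, math.ceil(-math.log10(diff)))

# Made-up loss values illustrating a perturbation at the 11th decimal.
loss_ref = 0.123456789012
loss_perturbed = 0.123456789052  # absolute difference ~4e-11
print(first_deviating_decimal(loss_ref, loss_perturbed))
# → 11
```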