
Implement early stopping #28

Merged · 6 commits · Sep 25, 2024

Conversation

akorgor
Collaborator

@akorgor akorgor commented May 29, 2024

This PR replaces PR #2 and implements the early-stopping algorithm as described in the corresponding evidence-accumulation task implemented in TensorFlow. The only difference is that here the early-stopping criterion is evaluated after each validation step (e.g., every ten iterations) rather than after every iteration as in the TensorFlow implementation. Since the criterion is always assessed against the newest validation value, re-evaluating it for ten iterations against the same validation result, as the TensorFlow implementation does, seems wasteful.
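The validation cadence described above can be sketched as follows; all names here (n_iter, n_validate_every, stop_crit, compute_validation_error) are illustrative placeholders, not the PR's actual API:

```python
# Minimal sketch: the early-stopping criterion is checked only at validation
# steps (every n_validate_every iterations), not after every iteration.
n_iter = 100
n_validate_every = 10
stop_crit = 0.65


def compute_validation_error(k_iter):
    # placeholder: a monotonically decreasing error, for illustration only
    return 0.75 - 0.002 * k_iter


stopped_at = None
for k_iter in range(n_iter):
    if k_iter % n_validate_every == 0:
        error_val = compute_validation_error(k_iter)
        # skip the check at k_iter == 0, before any weight update
        if k_iter > 0 and error_val < stop_crit:
            stopped_at = k_iter
            break
    # ... training step would run here ...

print(stopped_at)  # first validation step whose error falls below stop_crit
```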

Currently, the NEST losses match the TF losses only for the iterations in which no weight update has happened yet, i.e., the first two iterations.

NEST:
0.74115255000619 validation
0.75886758496570 training
0.64575804540432 training
0.65036521347625 training
0.73954336799350 training
0.64857381599914 training
0.63293547882357 test
0.74871743652812 test
0.63259857933630 test
0.65171656508917 test
TF:
0.74115252494812 validation
0.75886762142181 training
0.64488172531128 training
0.63414341211319 training
0.74553966522217 training
0.65522724390030 training
0.62222802639008 test
0.75369793176651 test
0.63924121856689 test
0.65074861049652 test
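As a rough check of how closely the two tables agree, the absolute differences of the first few entries can be computed directly (a throwaway snippet, not part of the PR):

```python
# First three entries of the NEST and TF loss tables above
nest_losses = [0.74115255000619, 0.75886758496570, 0.64575804540432]
tf_losses = [0.74115252494812, 0.75886762142181, 0.64488172531128]

diffs = [abs(n - t) for n, t in zip(nest_losses, tf_losses)]
for d in diffs:
    print(f"{d:.2e}")
# The first two entries (before any weight update) agree to ~1e-8,
# while the third already deviates in the third decimal digit.
```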

To test whether these deviations stem from an extra spike caused by numerical differences between TensorFlow and NEST, in the following experiment a recurrent neuron (index 82) was forced to emit an extra spike at t = 4000, i.e., during the second iteration. This perturbation changes the loss only in its 11th decimal digit.

NEST (perturbed):
0.74115255000619 validation
0.75886758494402 training
0.64732471934116 training
0.65337249687406 training
0.73455409141632 training
0.64412005064290 training
0.64022872139437 test
0.74554388366270 test
0.62969715145448 test
0.64486264066109 test
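The stated size of the perturbation effect can be verified from the tables: the second-row training loss with and without the forced spike first differs in the 11th decimal digit. The helper below is a throwaway check, not part of the PR:

```python
def first_diff_decimal_digit(a, b, precision=14):
    """1-based index of the first decimal digit at which a and b differ."""
    da = f"{a:.{precision}f}".split(".")[1]
    db = f"{b:.{precision}f}".split(".")[1]
    for i, (x, y) in enumerate(zip(da, db), start=1):
        if x != y:
            return i
    return None  # identical up to the given precision


# Second-iteration training loss, unperturbed vs. perturbed (tables above)
print(first_diff_decimal_digit(0.75886758496570, 0.75886758494402))  # -> 11
```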

Since a single extra spike perturbs the loss only far beyond the observed differences, the deviation between TF and NEST is therefore probably not a numerical effect but due to the gradients not being computed correctly.

  • investigate reason for deviation between TF and NEST

@akorgor akorgor requested a review from JesusEV May 29, 2024 08:31

Pull request automatically marked stale!

@akorgor
Collaborator Author

akorgor commented Sep 22, 2024

When the do_early_stopping flag is set to False, the losses of the original script without the early-stopping framework are recovered.

@github-actions github-actions bot removed the stale label Sep 23, 2024
Comment on lines 681 to 702
n_iter_sim = 0
nest.Simulate(duration["total_offset"])

phase_label_previous = ""
for k_iter in range(n_iter):
if do_early_stopping and k_iter % n_validate_every == 0:
error_val, n_iter_sim, phase_label_previous = run("validation", n_iter_sim, eta_test, phase_label_previous)

if k_iter > 0 and error_val < stop_crit:
errors_early_stop, n_iter_sim, phase_label_previous = run(
"early-stopping", n_iter_sim, eta_test, phase_label_previous
)
if np.mean(errors_early_stop) < stop_crit:
break

run_iter = min(n_iter - k_iter, n_validate_every)
_, n_iter_sim, phase_label_previous = run("training", n_iter_sim, eta_train, phase_label_previous)

for _ in range(n_test):
_, n_iter_sim, phase_label_previous = run("test", n_iter_sim, eta_test, phase_label_previous)

nest.Simulate(steps["extension_sim"])
Owner

Could the loop structure be reorganized to make the phase separation clearer? Currently, the only indication is the string labels, but they don’t seem obvious enough. Would a reorganization like the following be more effective?

def simulate_training_phase(n_iter_sim, eta_train, phase_label_previous):
    run_iter = min(n_iter - k_iter, n_validate_every)
    return run("training", n_iter_sim, eta_train, phase_label_previous)


def simulate_validation_phase(n_iter_sim, eta_test, phase_label_previous):
    return run("validation", n_iter_sim, eta_test, phase_label_previous)


def simulate_early_stopping_phase(n_iter_sim, eta_test, phase_label_previous):
    errors_early_stop, n_iter_sim, phase_label_previous = run("early-stopping", n_iter_sim, eta_test, phase_label_previous)
    if np.mean(errors_early_stop) < stop_crit:
        return True, n_iter_sim, phase_label_previous
    return False, n_iter_sim, phase_label_previous


def simulate_test_phase(n_test, n_iter_sim, eta_test, phase_label_previous):
    for _ in range(n_test):
        _, n_iter_sim, phase_label_previous = run("test", n_iter_sim, eta_test, phase_label_previous)


n_iter_sim = 0
nest.Simulate(duration["total_offset"])
phase_label_previous = ""

for k_iter in range(n_iter):
    # Validation phase and early stopping check
    if do_early_stopping and k_iter % n_validate_every == 0:
        error_val, n_iter_sim, phase_label_previous = simulate_validation_phase(n_iter_sim, eta_test, phase_label_previous)

        if k_iter > 0 and error_val < stop_crit:
            should_stop, n_iter_sim, phase_label_previous = simulate_early_stopping_phase(n_iter_sim, eta_test, phase_label_previous)
            if should_stop:
                print(f"Early stopping at iteration {k_iter}")
                break

    # Training phase
    _, n_iter_sim, phase_label_previous = simulate_training_phase(n_iter_sim, eta_train, phase_label_previous)

# Test phase
simulate_test_phase(n_test, n_iter_sim, eta_test, phase_label_previous)

Collaborator Author

Nice idea! Please see 4c13272.

Owner

It looks great. The compartmentalization makes it look much cleaner and more organized. I'm sure this should make maintenance and modification easier. In particular, I like the new run_early_stopping function, which isolates the core early-stopping logic into a few easy-to-follow lines.

@akorgor akorgor merged commit 3b1b930 into JesusEV:eprop_bio_feature Sep 25, 2024
17 of 19 checks passed
@akorgor akorgor deleted the feat_early-stopping branch September 25, 2024 14:32
akorgor pushed a commit that referenced this pull request Sep 27, 2024
Changing parameter values for astrocyte_lr_1994