Skip to content

Commit

Permalink
edit to lif
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 1, 2024
1 parent c89ba39 commit a45dbbd
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ngclearn/components/neurons/spiking/LIFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import straight_through_estimator, triangular_estimator
from ngclearn.utils.surrogate_fx import arctan_estimator, triangular_estimator

@jit
def _update_times(t, s, tols):
Expand Down Expand Up @@ -237,6 +237,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.n_units = n_units

## set up surrogate function for spike emission
#self.spike_fx, self.d_spike_fx = arctan_estimator() #
self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()

## Compartment setup
Expand Down Expand Up @@ -264,10 +265,11 @@ def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
key, skey = random.split(key, 2)
## run one integration step for neuronal dynamics
j = j * R_m
surrogate = d_spike_fx(j, thr + thr_theta)
#surrogate = d_spike_fx(v, thr + thr_theta)
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
tau_m, v_rest, v_reset, v_decay, refract_T,
intgFlag)
surrogate = d_spike_fx(v, thr + thr_theta)
if tau_theta > 0.:
## run one integration step for threshold dynamics
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
Expand Down

0 comments on commit a45dbbd

Please sign in to comment.