Skip to content

Commit

Permalink
cleaned up raf
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Aug 8, 2024
1 parent dd49e5f commit 8882208
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions ngclearn/components/neurons/spiking/RAFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ class RAFCell(JaxComponent):
The specific pair of differential equations that characterize this cell
are (for adjusting v and w, given current j, over time):
| tau_m * dv/dt = -(v - v_rest) + sharpV * exp((v - vT)/sharpV) - R_m * w + R_m * j
| tau_w * dw/dt = -w + (v - v_rest) * a
| where w = w + s * (w + b) [in the event of a spike]
| tau_m * dv/dt = omega * w + v * b
| tau_w * dw/dt = w * b - v * omega + j
| where omega is angular frequency (Hz) and b is exponential dampening factor
| --- Cell Input Compartments: ---
| j - electrical current input (takes in external signals)
Expand Down Expand Up @@ -93,13 +93,11 @@ class RAFCell(JaxComponent):
thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes) (Default: 5 mV)
v_rest: membrane resting potential (Default: -72 mV)
v_reset: membrane reset potential condition (Default: 0 mV)
b: oscillation dampening factor (Default: -1.)
v0: initial condition / reset for voltage (Default: -70 mV)
w_reset: reset condition for angular driver (Default: 0 mV)
w0: initial condition / reset for angular driver (Default: 0 mV)
b: oscillation dampening factor (Default: -1.)
integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
Expand All @@ -112,9 +110,9 @@ class RAFCell(JaxComponent):

# Define Functions
def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
omega=10., thr=5., v_rest=-72.,
v_reset=-75., w_reset=0., b=-1., v0=-70., w0=0.,
omega=10., thr=5., v_reset=0., w_reset=0., b=-1.,
integration_type="euler", batch_size=1, **kwargs):
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0.,
super().__init__(name, **kwargs)

## Integration properties
Expand All @@ -128,12 +126,9 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
self.omega = omega ## angular frequency
self.b = b ## dampening factor
## note: the smaller b is, the faster the oscillation dampens to resting state values
self.v_rest = v_rest
#self.v_rest = v_rest
self.v_reset = v_reset
self.w_reset = w_reset

self.v0 = v0 ## initial membrane potential/voltage condition
self.w0 = w0 ## initial w-parameter condition
self.thr = thr

## Layer Size Setup
Expand All @@ -150,8 +145,12 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
units="ms") ## time-of-last-spike

@staticmethod
def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, v_rest,
def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b,
v_reset, w_reset, intgFlag, j, v, w, tols):
## center variables before running dynamics
v = v - v_reset
w = w - w_reset
## continue with centered dynamics
j_ = j * resist_m
if intgFlag == 1: ## RK-2/midpoint
w_params = (j_, v, tau_w, omega, b)
Expand All @@ -165,9 +164,11 @@ def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, v_rest,
_, _v = step_euler(0., v, _dfv, dt, v_params)
s = _emit_spike(_v, thr)
## hyperpolarize/reset/snap variables
v = _v * (1. - s) + s * v_reset
w = _w * (1. - s) + s * w_reset

v = _v * (1. - s) + s #* v_reset
w = _w * (1. - s) + s #* w_reset
## artificially shift variables back to rest/reset values
v = v + v_reset
w = w + w_reset
tols = _update_times(t, s, tols)
return j, v, w, s, tols

Expand All @@ -180,11 +181,11 @@ def advance_state(self, j, v, w, s, tols):
self.tols.set(tols)

@staticmethod
def _reset(batch_size, n_units, v0, w0):
def _reset(batch_size, n_units, v_reset, w_reset):
restVals = jnp.zeros((batch_size, n_units))
j = restVals # None
v = restVals + v0
w = restVals + w0
v = restVals + v_reset
w = restVals + w_reset
s = restVals #+ 0
tols = restVals #+ 0
return j, v, w, s, tols
Expand Down

0 comments on commit 8882208

Please sign in to comment.