diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py index d4dfb844..fae6b08a 100755 --- a/ngclearn/utils/diffeq/ode_utils.py +++ b/ngclearn/utils/diffeq/ode_utils.py @@ -55,7 +55,8 @@ def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine _x = x * x_scale + dx_dt * dt return _t, _x -@partial(jit, static_argnums=(2, )) + + def step_euler(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via the Euler method, i.e., a @@ -83,7 +84,7 @@ def step_euler(t, x, dfx, dt, params, x_scale=1.): _t, _x = _step_forward(t, x, dx_dt, dt, x_scale) return _t, _x -@partial(jit, static_argnums=(2, )) + def step_heun(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via Heun's method, i.e., a @@ -124,7 +125,7 @@ def step_heun(t, x, dfx, dt, params, x_scale=1.): _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale) return _t, _x -@partial(jit, static_argnums=(2, )) + def step_rk2(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via the midpoint method, i.e., a @@ -165,7 +166,7 @@ def step_rk2(t, x, dfx, dt, params, x_scale=1.): -@partial(jit, static_argnums=(2, )) + def step_rk4(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via the midpoint method, i.e., a @@ -211,7 +212,7 @@ def step_rk4(t, x, dfx, dt, params, x_scale=1.): _t, _x = _step_forward(t, x, _dx_dt / 6, dt, x_scale) return _t, _x -@partial(jit, static_argnums=(2, )) + def step_ralston(t, x, dfx, dt, params, x_scale=1.): """ Iteratively integrates one step forward via Ralston's method, i.e., a