Skip to content

Commit

Permalink
improve handling of NEST synaptic delay parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Sep 15, 2023
1 parent f221e5c commit 01a1e47
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ public:
}

{%- endif %}

};


Expand Down Expand Up @@ -333,7 +334,7 @@ private:
{%- for variable_symbol in synapse.get_parameter_symbols() %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if not isHomogeneous and not (variable_symbol.namespace_decorators and variable_symbol.namespace_decorators | length > 0 and variable_symbol.namespace_decorators[0] == "delay") %}
{%- if not isHomogeneous and not (variable_symbol.namespace_decorators and variable_symbol.namespace_decorators | length > 0 and variable_symbol.namespace_decorators["nest"] == "delay") %}
{%- include 'directives/MemberDeclaration.jinja2' %}
{%- else %}
// N.B. the parameter `{{ printer.print(variable) }}` is defined in the common properties class
Expand Down Expand Up @@ -1054,9 +1055,8 @@ void
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- if (not is_delta_kernel(synapse.get_kernel_by_name(variable_symbol.get_symbol_name()))) and (not variable_symbol.is_internals()) %}
{%- if variable_symbol.get_namespace_decorator("nest")|length > 0 %}
// special treatment for variable marked with @nest::name decorator
{%- set nest_namespace_name = variable_symbol.get_namespace_decorator("nest") %}
def<{{declarations.print_variable_type(variable_symbol)}}>(__d, names::{{nest_namespace_name}}, get_{{printer_no_origin.print(variable)}}());
def<{{declarations.print_variable_type(variable_symbol)}}>(__d, names::{{nest_namespace_name}}, get_{{printer_no_origin.print(variable)}}()); // special treatment for variable marked with @nest::name decorator
{%- else %}
{%- include "directives/WriteInDictionary.jinja2" %}
{%- endif %}
Expand Down Expand Up @@ -1194,6 +1194,8 @@ updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, nest::
{%- set namespaceName = variable_symbol.get_namespace_decorator("nest") %}
{%- if namespaceName | length > 0 %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
if ( __d->known(nest::names::{{ namespaceName }}) )
{
{%- if namespaceName == "delay" %}
// special treatment of NEST delay
set_delay(tmp_delay);
Expand All @@ -1203,6 +1205,7 @@ updateValue<{{ declarations.print_variable_type(variable_symbol) }}>(__d, nest::
{%- else %}
{{ raise("Unknown NEST namespace name: " + namespaceName) }}
{%- endif %}
}
{%- endif %}
{%- endfor %}

Expand Down Expand Up @@ -1281,20 +1284,8 @@ template < typename targetidentifierT >
{%- for variable_symbol in synapse.get_parameter_symbols() %}
{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- set has_namespace_decorator = (variable_symbol.namespace_decorators and variable_symbol.namespace_decorators | length > 0) %}
{%- if not isHomogeneous %}
{%- if not has_namespace_decorator %}
{%- if not isHomogeneous and not (variable_symbol.namespace_decorators and variable_symbol.namespace_decorators | length > 0 and variable_symbol.namespace_decorators["nest"] == "delay") %}
{{ printer.print(variable) }} = rhs.{{ printer.print(variable) }};
{%- else %}
{%- set namespaceName = variable_symbol.get_namespace_decorator("nest") %}
{%- if namespaceName == "delay" %}
set_delay(rhs.get_delay());
{%- elif namespaceName == "weight" %}
set_weight(rhs.get_weight());
{%- else %}
{{ raise("Unknown NEST namespace name: " + namespaceName) }}
{%- endif %}
{%- endif %}
{%- endif %}
{%- endfor %}

Expand All @@ -1304,8 +1295,6 @@ template < typename targetidentifierT >
{{ printer.print(variable) }} = rhs.{{ printer.print(variable) }};
{%- endfor %}

//weight_ = get_named_parameter<double>(names::weight);
//set_weight( *rhs.weight_ );
{%- if vt_ports is defined and vt_ports|length > 0 %}
t_last_update_ = rhs.t_last_update_;
{%- endif %}
Expand Down
10 changes: 3 additions & 7 deletions tests/nest_tests/stdp_triplet_synapse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ def run_reference_simulation(syn_opts,
}

for spk_time in np.unique(times_spikes_syn_persp):
import logging
logging.warning("XXX: TODO: before_increment values here are all wrong")

if spk_time in times_spikes_post_syn_persp:
# print("Post spike --> facilitation")
print("\tgetting pre trace r1")
Expand All @@ -132,7 +129,7 @@ def run_reference_simulation(syn_opts,
weight = np.clip(weight + r1 * (syn_opts["A2_plus"] + syn_opts["A3_plus"]
* o2), a_min=syn_opts["w_min"], a_max=syn_opts["w_max"])
# print("\tnew weight = " + str(weight))
print("[NESTML] stdp_connection: facilitating from " + str(old_weight) + " to "
print("[REF] stdp_connection: facilitating from " + str(old_weight) + " to "
+ str(weight) + " with pre tr = " + str(r1) + ", post tr = " + str(o2))

if spk_time in times_spikes_pre:
Expand All @@ -150,7 +147,7 @@ def run_reference_simulation(syn_opts,
weight = np.clip(weight - o1 * (syn_opts["A2_minus"] + syn_opts["A3_minus"]
* r2), a_min=syn_opts["w_min"], a_max=syn_opts["w_max"])
# print("\tnew weight = " + str(weight))
print("[NESTML] stdp_connection: depressing from " + str(old_weight) + " to "
print("[REF] stdp_connection: depressing from " + str(old_weight) + " to "
+ str(weight) + " with pre tr = " + str(r2) + ", post tr = " + str(o1))

log[spk_time] = {"weight": weight}
Expand Down Expand Up @@ -231,7 +228,6 @@ def run_nest_simulation(neuron_model_name,
_syn_opts["Wmax"] = _syn_opts.pop("w_max")
_syn_opts["Wmin"] = _syn_opts.pop("w_min")
_syn_opts["weight"] = _syn_opts.pop("w_init")
_syn_opts.pop("delay")
nest.CopyModel(synapse_model_name,
synapse_model_name + "_rec",
{"weight_recorder": weight_recorder_E[0]})
Expand Down Expand Up @@ -265,6 +261,7 @@ def run_nest_simulation(neuron_model_name,
events = nest.GetStatus(weight_recorder_E, "events")[0]
times_weights = events["times"]
weight_simulation = events["weights"]

return times_weights, weight_simulation, gid_pre, gid_post, times_spikes, senders_spikes, sim_time


Expand All @@ -287,7 +284,6 @@ def compare_results(timevec, weight_reference, times_weights, weight_simulation)
w_ref_vec.append(w_ref)

np.testing.assert_allclose(weight_simulation, w_ref_vec, atol=1E-6, rtol=1E-6)
print("Test passed!")


def plot_comparison(syn_opts, times_spikes_pre, times_spikes_post, times_spikes_post_syn_persp, timevec, weight_reference, times_weights, weight_simulation, sim_time):
Expand Down

0 comments on commit 01a1e47

Please sign in to comment.