From b10be5f22d382010d1bb236785b06c5ad9ee64e7 Mon Sep 17 00:00:00 2001 From: Pooja Babu <75320801+pnbabu@users.noreply.github.com> Date: Thu, 17 Aug 2023 17:24:34 +0200 Subject: [PATCH] Fix type check for inline expressions (#945) --- doc/running.rst | 14 ++++++++------ models/neurons/hh_moto_5ht.nestml | 2 +- models/neurons/hill_tononi.nestml | 10 +++++----- models/neurons/terub_stn.nestml | 4 ++-- pynestml/cocos/co_co_illegal_expression.py | 17 +++++++++++++++-- .../codegeneration/printers/nestml_printer.py | 4 ++-- pynestml/utils/ast_utils.py | 9 ++++++++- tests/cocos_test.py | 4 ++-- tests/nest_tests/nest_multisynapse_test.py | 2 +- .../CoCoConvolveNotCorrectlyProvided.nestml | 2 +- 10 files changed, 45 insertions(+), 23 deletions(-) diff --git a/doc/running.rst b/doc/running.rst index 498785410..39572474e 100644 --- a/doc/running.rst +++ b/doc/running.rst @@ -199,14 +199,16 @@ After generating and building the model code, a ``receptor_type`` entry is avail neuron = nest.Create("iaf_psc_exp_multisynapse_neuron_nestml") + receptor_types = nest.GetStatus(neuron, "receptor_types")[0] + sg = nest.Create("spike_generator", params={"spike_times": [20., 80.]}) - nest.Connect(sg, neuron, syn_spec={"receptor_type" : 1, "weight": 1000.}) + nest.Connect(sg, neuron, syn_spec={"receptor_type" : receptor_types["SPIKES_1"], "weight": 1000.}) sg2 = nest.Create("spike_generator", params={"spike_times": [40., 60.]}) - nest.Connect(sg2, neuron, syn_spec={"receptor_type" : 2, "weight": 1000.}) + nest.Connect(sg2, neuron, syn_spec={"receptor_type" : receptor_types["SPIKES_2"], "weight": 1000.}) sg3 = nest.Create("spike_generator", params={"spike_times": [30., 70.]}) - nest.Connect(sg3, neuron, syn_spec={"receptor_type" : 3, "weight": 500.}) + nest.Connect(sg3, neuron, syn_spec={"receptor_type" : receptor_types["SPIKES_3"], "weight": 500.}) Note that in multisynapse neurons, receptor ports are numbered starting from 1. @@ -214,9 +216,9 @@ We furthermore wish to record the synaptic currents ``I_kernel1``, ``I_kernel2`` .. code-block:: python - mm = nest.Create('multimeter', params={'record_from': ['I_kernel1__X__spikes1', - 'I_kernel2__X__spikes2', - 'I_kernel3__X__spikes3'], + mm = nest.Create('multimeter', params={'record_from': ['I_kernel1__X__spikes_1', + 'I_kernel2__X__spikes_2', + 'I_kernel3__X__spikes_3'], 'interval': .1}) nest.Connect(mm, neuron) diff --git a/models/neurons/hh_moto_5ht.nestml b/models/neurons/hh_moto_5ht.nestml index 14989afce..b51dbe9d0 100644 --- a/models/neurons/hh_moto_5ht.nestml +++ b/models/neurons/hh_moto_5ht.nestml @@ -49,7 +49,7 @@ neuron hh_moto_5ht: inline I_syn_exc pA = convolve(I_syn_ex, exc_spikes) inline I_syn_inh pA = convolve(I_syn_in, inh_spikes) - inline E_Ca mV = ((1000.0 * R_const * T_current) / (2. * F_const)) * log10(Ca_out / Ca_in) + inline E_Ca mV = ((1000.0 * R_const * T_current) / (2. * F_const)) * log10(Ca_out / Ca_in) * mV inline I_Na pA = g_Na * Act_m * Act_m * Act_m * Act_h * ( V_m - E_Na ) inline I_K pA = g_K_rect * Inact_n * Inact_n * Inact_n * Inact_n * ( V_m - E_K ) diff --git a/models/neurons/hill_tononi.nestml b/models/neurons/hill_tononi.nestml index e0bf5acb8..da1a61894 100644 --- a/models/neurons/hill_tononi.nestml +++ b/models/neurons/hill_tononi.nestml @@ -79,11 +79,11 @@ neuron hill_tononi: recordable inline I_KNa pA = -KNa_g_peak * m_inf_KNa * ( V_m - KNa_E_rev ) # Low-thresh Ca current; member only to allow recording - recordable inline I_T pA = -T_g_peak * IT_m * IT_m * IT_h * ( V_m - T_E_rev ) + recordable inline I_T pA = -T_g_peak / nS * IT_m / nS * IT_m / nS * IT_h * ( V_m - T_E_rev ) - recordable inline I_h pA = -h_g_peak * Ih_m * ( V_m - h_E_rev ) + recordable inline I_h pA = -h_g_peak / nS * Ih_m * ( V_m - h_E_rev ) # The spike current is only activate immediately after a spike. - inline I_spike mV = (g_spike) ? -( V_m - E_K ) / Tau_spike : 0 + inline I_spike mV = (g_spike) ? -( V_m - E_K ) / Tau_spike * ms : 0 mV V_m' = ( ( I_Na + I_K + I_syn + I_NaP + I_KNa + I_T + I_h + I_e + I_stim ) / Tau_m + I_spike * pA/(ms * mV) ) * s/nF ############# @@ -96,8 +96,8 @@ neuron hill_tononi: # I_KNa inline D_influx_peak real = 0.025 inline tau_D real = 1250.0 # yes, 1.25 s - inline D_thresh mV = -10.0 - inline D_slope mV = 5.0 + inline D_thresh mV = -10.0 mV + inline D_slope mV = 5.0 mV inline D_influx real = 1.0 / ( 1.0 + exp( -( V_m - D_thresh ) / D_slope ) ) Theta' = -( Theta - Theta_eq ) / Tau_theta diff --git a/models/neurons/terub_stn.nestml b/models/neurons/terub_stn.nestml index c8a48bd3e..156cb2d0e 100644 --- a/models/neurons/terub_stn.nestml +++ b/models/neurons/terub_stn.nestml @@ -81,8 +81,8 @@ neuron terub_stn: inline k_Ca real = 22.5 inline k1 real = 15.0 - inline I_exc_mod pA = -convolve(g_exc, exc_spikes) * V_m - inline I_inh_mod pA = convolve(g_inh, inh_spikes) * (V_m - E_gs) + inline I_exc_mod pA = -convolve(g_exc, exc_spikes) * V_m / mV + inline I_inh_mod pA = convolve(g_inh, inh_spikes) * (V_m - E_gs) / mV inline tau_n ms = tau_n_0 + tau_n_1 / (1. + exp(-(V_m-theta_n_tau)/sigma_n_tau)) inline tau_h ms = tau_h_0 + tau_h_1 / (1. + exp(-(V_m-theta_h_tau)/sigma_h_tau)) diff --git a/pynestml/cocos/co_co_illegal_expression.py b/pynestml/cocos/co_co_illegal_expression.py index 71e7c2cb0..a1598c08d 100644 --- a/pynestml/cocos/co_co_illegal_expression.py +++ b/pynestml/cocos/co_co_illegal_expression.py @@ -18,6 +18,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.meta_model.ast_declaration import ASTDeclaration @@ -71,6 +72,19 @@ def visit_declaration(self, node): TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression()) return + def visit_inline_expression(self, node): + """ + Visits a single inline expression and asserts that type of lhs is equal to type of rhs. + """ + assert isinstance(node, ASTInlineExpression) + lhs_type = node.get_data_type().get_type_symbol() + rhs_type = node.get_expression().type + if isinstance(rhs_type, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + if self.__types_do_not_match(lhs_type, rhs_type): + TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression()) + def visit_assignment(self, node): """ Visits a single expression and assures that type(lhs) == type(rhs). @@ -231,8 +245,7 @@ def visit_for_stmt(self, node): Logger.log_message(code=code, message=message, error_position=node.get_start_from().get_source_position(), log_level=LoggingLevel.ERROR) elif not (from_type.equals(PredefinedTypes.get_integer_type()) - or from_type.equals( - PredefinedTypes.get_real_type())): + or from_type.equals(PredefinedTypes.get_real_type())): code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_integer_type(), from_type) Logger.log_message(code=code, message=message, error_position=node.get_start_from().get_source_position(), diff --git a/pynestml/codegeneration/printers/nestml_printer.py b/pynestml/codegeneration/printers/nestml_printer.py index 2e213b566..c822940b5 100644 --- a/pynestml/codegeneration/printers/nestml_printer.py +++ b/pynestml/codegeneration/printers/nestml_printer.py @@ -253,7 +253,7 @@ def print_declaration(self, node: ASTDeclaration) -> str: ret += "," ret += " " + self.print(node.get_data_type()) + " " if node.has_size_parameter(): - ret += "[" + node.get_size_parameter() + "] " + ret += "[" + self.print(node.get_size_parameter()) + "] " if node.has_expression(): ret += "= " + self.print(node.get_expression()) if node.has_invariant(): @@ -370,7 +370,7 @@ def print_input_port(self, node: ASTInputPort) -> str: if node.has_datatype(): ret += " " + self.print(node.get_datatype()) + " " if node.has_size_parameter(): - ret += "[" + node.get_size_parameter() + "]" + ret += "[" + self.print(node.get_size_parameter()) + "]" ret += "<- " if node.has_input_qualifiers(): for qual in node.get_input_qualifiers(): diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 6fba498a7..9a31e9afe 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -1239,7 +1239,7 @@ def construct_kernel_X_spike_buf_name(cls, kernel_var_name: str, spike_input_por spike_input_port_name = spike_input_port.get_name() if spike_input_port.has_vector_parameter(): - spike_input_port_name += str(cls.get_numeric_vector_size(spike_input_port)) + spike_input_port_name += "_" + str(cls.get_numeric_vector_size(spike_input_port)) return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + spike_input_port_name + diff_order_symbol * order @@ -1332,6 +1332,13 @@ def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: st """ for input_block in input_blocks: for input_port in input_block.get_input_ports(): + if input_port.has_size_parameter(): + size_parameter = input_port.get_size_parameter() + if isinstance(size_parameter, ASTSimpleExpression): + size_parameter = size_parameter.get_numeric_literal() + port_name, port_index = port_name.split("_") + assert int(port_index) > 0 + assert int(port_index) <= size_parameter if input_port.name == port_name: return input_port return None diff --git a/tests/cocos_test.py b/tests/cocos_test.py index 9d2ad6935..f4a17f6c0 100644 --- a/tests/cocos_test.py +++ b/tests/cocos_test.py @@ -407,7 +407,7 @@ def test_invalid_convolve_correctly_defined(self): os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyProvided.nestml')) self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], - LoggingLevel.ERROR)), 2) + LoggingLevel.ERROR)), 3) def test_valid_convolve_correctly_defined(self): Logger.set_logging_level(LoggingLevel.INFO) @@ -487,7 +487,7 @@ def test_invalid_convolve_correctly_parameterized(self): os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyParametrized.nestml')) self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], LoggingLevel.ERROR)), 1) + Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], LoggingLevel.ERROR)), 2) def test_valid_convolve_correctly_parameterized(self): Logger.set_logging_level(LoggingLevel.INFO) diff --git a/tests/nest_tests/nest_multisynapse_test.py b/tests/nest_tests/nest_multisynapse_test.py index 68ca9c78f..5fa08136f 100644 --- a/tests/nest_tests/nest_multisynapse_test.py +++ b/tests/nest_tests/nest_multisynapse_test.py @@ -147,7 +147,7 @@ def test_multisynapse_with_vector_input_ports(self): nest.Connect(sg3, neuron, syn_spec={"receptor_type": receptor_types["SPIKES_3"], "weight": 500., "delay": 0.1}) mm = nest.Create("multimeter", params={"record_from": [ - "I_kernel1__X__spikes1", "I_kernel2__X__spikes2", "I_kernel3__X__spikes3"], "interval": 0.1}) + "I_kernel1__X__spikes_1", "I_kernel2__X__spikes_2", "I_kernel3__X__spikes_3"], "interval": 0.1}) nest.Connect(mm, neuron) vm_1 = nest.Create("voltmeter") diff --git a/tests/valid/CoCoConvolveNotCorrectlyProvided.nestml b/tests/valid/CoCoConvolveNotCorrectlyProvided.nestml index 2fd28bd43..a792037ff 100644 --- a/tests/valid/CoCoConvolveNotCorrectlyProvided.nestml +++ b/tests/valid/CoCoConvolveNotCorrectlyProvided.nestml @@ -34,7 +34,7 @@ along with NEST. If not, see . neuron CoCoConvolveNotCorrectlyProvided: equations: kernel test = 10 - inline testB pA = convolve(test, spikeExc) # convolve provided with a kernel and a spike input port, thus correct + inline testB pA = convolve(test, spikeExc) * pA # convolve provided with a kernel and a spike input port, thus correct input: spikeExc integer <- excitatory spike