Skip to content

Commit

Permalink
Fix type check for inline expressions (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu authored Aug 17, 2023
1 parent 092f5fa commit b10be5f
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 23 deletions.
14 changes: 8 additions & 6 deletions doc/running.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,24 +199,26 @@ After generating and building the model code, a ``receptor_type`` entry is avail
neuron = nest.Create("iaf_psc_exp_multisynapse_neuron_nestml")
receptor_types = nest.GetStatus(neuron, "receptor_types")[0]
sg = nest.Create("spike_generator", params={"spike_times": [20., 80.]})
nest.Connect(sg, neuron, syn_spec={"receptor_type" : 1, "weight": 1000.})
nest.Connect(sg, neuron, syn_spec={"receptor_type" : receptor_types["SPIKES_1"], "weight": 1000.})
sg2 = nest.Create("spike_generator", params={"spike_times": [40., 60.]})
nest.Connect(sg2, neuron, syn_spec={"receptor_type" : 2, "weight": 1000.})
nest.Connect(sg2, neuron, syn_spec={"receptor_type" : receptor_types["SPIKES_2"], "weight": 1000.})
sg3 = nest.Create("spike_generator", params={"spike_times": [30., 70.]})
nest.Connect(sg3, neuron, syn_spec={"receptor_type" : 3, "weight": 500.})
nest.Connect(sg3, neuron, syn_spec={"receptor_type" : receptor_types["SPIKES_3"], "weight": 500.})
Note that in multisynapse neurons, receptor ports are numbered starting from 1.

We furthermore wish to record the synaptic currents ``I_kernel1``, ``I_kernel2`` and ``I_kernel3``. During code generation, one buffer is created for each combination of (kernel, spike input port) that appears in convolution statements. These buffers are named by joining together the name of the kernel with the name of the spike buffer using (by default) the string "__X__". The variables to be recorded are thus named as follows:

.. code-block:: python
mm = nest.Create('multimeter', params={'record_from': ['I_kernel1__X__spikes1',
'I_kernel2__X__spikes2',
'I_kernel3__X__spikes3'],
mm = nest.Create('multimeter', params={'record_from': ['I_kernel1__X__spikes_1',
'I_kernel2__X__spikes_2',
'I_kernel3__X__spikes_3'],
'interval': .1})
nest.Connect(mm, neuron)
Expand Down
2 changes: 1 addition & 1 deletion models/neurons/hh_moto_5ht.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ neuron hh_moto_5ht:
inline I_syn_exc pA = convolve(I_syn_ex, exc_spikes)
inline I_syn_inh pA = convolve(I_syn_in, inh_spikes)

inline E_Ca mV = ((1000.0 * R_const * T_current) / (2. * F_const)) * log10(Ca_out / Ca_in)
inline E_Ca mV = ((1000.0 * R_const * T_current) / (2. * F_const)) * log10(Ca_out / Ca_in) * mV

inline I_Na pA = g_Na * Act_m * Act_m * Act_m * Act_h * ( V_m - E_Na )
inline I_K pA = g_K_rect * Inact_n * Inact_n * Inact_n * Inact_n * ( V_m - E_K )
Expand Down
10 changes: 5 additions & 5 deletions models/neurons/hill_tononi.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ neuron hill_tononi:
recordable inline I_KNa pA = -KNa_g_peak * m_inf_KNa * ( V_m - KNa_E_rev )

# Low-thresh Ca current; member only to allow recording
recordable inline I_T pA = -T_g_peak * IT_m * IT_m * IT_h * ( V_m - T_E_rev )
recordable inline I_T pA = -T_g_peak / nS * IT_m / nS * IT_m / nS * IT_h * ( V_m - T_E_rev )

recordable inline I_h pA = -h_g_peak * Ih_m * ( V_m - h_E_rev )
recordable inline I_h pA = -h_g_peak / nS * Ih_m * ( V_m - h_E_rev )
# The spike current is only activate immediately after a spike.
inline I_spike mV = (g_spike) ? -( V_m - E_K ) / Tau_spike : 0
inline I_spike mV = (g_spike) ? -( V_m - E_K ) / Tau_spike * ms : 0 mV
V_m' = ( ( I_Na + I_K + I_syn + I_NaP + I_KNa + I_T + I_h + I_e + I_stim ) / Tau_m + I_spike * pA/(ms * mV) ) * s/nF

#############
Expand All @@ -96,8 +96,8 @@ neuron hill_tononi:
# I_KNa
inline D_influx_peak real = 0.025
inline tau_D real = 1250.0 # yes, 1.25 s
inline D_thresh mV = -10.0
inline D_slope mV = 5.0
inline D_thresh mV = -10.0 mV
inline D_slope mV = 5.0 mV
inline D_influx real = 1.0 / ( 1.0 + exp( -( V_m - D_thresh ) / D_slope ) )

Theta' = -( Theta - Theta_eq ) / Tau_theta
Expand Down
4 changes: 2 additions & 2 deletions models/neurons/terub_stn.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ neuron terub_stn:
inline k_Ca real = 22.5
inline k1 real = 15.0

inline I_exc_mod pA = -convolve(g_exc, exc_spikes) * V_m
inline I_inh_mod pA = convolve(g_inh, inh_spikes) * (V_m - E_gs)
inline I_exc_mod pA = -convolve(g_exc, exc_spikes) * V_m / mV
inline I_inh_mod pA = convolve(g_inh, inh_spikes) * (V_m - E_gs) / mV

inline tau_n ms = tau_n_0 + tau_n_1 / (1. + exp(-(V_m-theta_n_tau)/sigma_n_tau))
inline tau_h ms = tau_h_0 + tau_h_1 / (1. + exp(-(V_m-theta_h_tau)/sigma_h_tau))
Expand Down
17 changes: 15 additions & 2 deletions pynestml/cocos/co_co_illegal_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression

from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.meta_model.ast_declaration import ASTDeclaration
Expand Down Expand Up @@ -71,6 +72,19 @@ def visit_declaration(self, node):
TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression())
return

def visit_inline_expression(self, node):
"""
Visits a single inline expression and asserts that type of lhs is equal to type of rhs.
"""
assert isinstance(node, ASTInlineExpression)
lhs_type = node.get_data_type().get_type_symbol()
rhs_type = node.get_expression().type
if isinstance(rhs_type, ErrorTypeSymbol):
LoggingHelper.drop_missing_type_error(node)
return
if self.__types_do_not_match(lhs_type, rhs_type):
TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression())

def visit_assignment(self, node):
"""
Visits a single expression and assures that type(lhs) == type(rhs).
Expand Down Expand Up @@ -231,8 +245,7 @@ def visit_for_stmt(self, node):
Logger.log_message(code=code, message=message, error_position=node.get_start_from().get_source_position(),
log_level=LoggingLevel.ERROR)
elif not (from_type.equals(PredefinedTypes.get_integer_type())
or from_type.equals(
PredefinedTypes.get_real_type())):
or from_type.equals(PredefinedTypes.get_real_type())):
code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_integer_type(),
from_type)
Logger.log_message(code=code, message=message, error_position=node.get_start_from().get_source_position(),
Expand Down
4 changes: 2 additions & 2 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def print_declaration(self, node: ASTDeclaration) -> str:
ret += ","
ret += " " + self.print(node.get_data_type()) + " "
if node.has_size_parameter():
ret += "[" + node.get_size_parameter() + "] "
ret += "[" + self.print(node.get_size_parameter()) + "] "
if node.has_expression():
ret += "= " + self.print(node.get_expression())
if node.has_invariant():
Expand Down Expand Up @@ -370,7 +370,7 @@ def print_input_port(self, node: ASTInputPort) -> str:
if node.has_datatype():
ret += " " + self.print(node.get_datatype()) + " "
if node.has_size_parameter():
ret += "[" + node.get_size_parameter() + "]"
ret += "[" + self.print(node.get_size_parameter()) + "]"
ret += "<- "
if node.has_input_qualifiers():
for qual in node.get_input_qualifiers():
Expand Down
9 changes: 8 additions & 1 deletion pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ def construct_kernel_X_spike_buf_name(cls, kernel_var_name: str, spike_input_por

spike_input_port_name = spike_input_port.get_name()
if spike_input_port.has_vector_parameter():
spike_input_port_name += str(cls.get_numeric_vector_size(spike_input_port))
spike_input_port_name += "_" + str(cls.get_numeric_vector_size(spike_input_port))

return kernel_var_name.replace("$", "__DOLLAR") + "__X__" + spike_input_port_name + diff_order_symbol * order

Expand Down Expand Up @@ -1332,6 +1332,13 @@ def get_input_port_by_name(cls, input_blocks: List[ASTInputBlock], port_name: st
"""
for input_block in input_blocks:
for input_port in input_block.get_input_ports():
if input_port.has_size_parameter():
size_parameter = input_port.get_size_parameter()
if isinstance(size_parameter, ASTSimpleExpression):
size_parameter = size_parameter.get_numeric_literal()
port_name, port_index = port_name.split("_")
assert int(port_index) > 0
assert int(port_index) <= size_parameter
if input_port.name == port_name:
return input_port
return None
Expand Down
4 changes: 2 additions & 2 deletions tests/cocos_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def test_invalid_convolve_correctly_defined(self):
os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
'CoCoConvolveNotCorrectlyProvided.nestml'))
self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0],
LoggingLevel.ERROR)), 2)
LoggingLevel.ERROR)), 3)

def test_valid_convolve_correctly_defined(self):
Logger.set_logging_level(LoggingLevel.INFO)
Expand Down Expand Up @@ -487,7 +487,7 @@ def test_invalid_convolve_correctly_parameterized(self):
os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
'CoCoConvolveNotCorrectlyParametrized.nestml'))
self.assertEqual(len(
Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], LoggingLevel.ERROR)), 1)
Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], LoggingLevel.ERROR)), 2)

def test_valid_convolve_correctly_parameterized(self):
Logger.set_logging_level(LoggingLevel.INFO)
Expand Down
2 changes: 1 addition & 1 deletion tests/nest_tests/nest_multisynapse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_multisynapse_with_vector_input_ports(self):
nest.Connect(sg3, neuron, syn_spec={"receptor_type": receptor_types["SPIKES_3"], "weight": 500., "delay": 0.1})

mm = nest.Create("multimeter", params={"record_from": [
"I_kernel1__X__spikes1", "I_kernel2__X__spikes2", "I_kernel3__X__spikes3"], "interval": 0.1})
"I_kernel1__X__spikes_1", "I_kernel2__X__spikes_2", "I_kernel3__X__spikes_3"], "interval": 0.1})
nest.Connect(mm, neuron)

vm_1 = nest.Create("voltmeter")
Expand Down
2 changes: 1 addition & 1 deletion tests/valid/CoCoConvolveNotCorrectlyProvided.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ along with NEST. If not, see <http://www.gnu.org/licenses/>.
neuron CoCoConvolveNotCorrectlyProvided:
equations:
kernel test = 10
inline testB pA = convolve(test, spikeExc) # convolve provided with a kernel and a spike input port, thus correct
inline testB pA = convolve(test, spikeExc) * pA # convolve provided with a kernel and a spike input port, thus correct

input:
spikeExc integer <- excitatory spike
Expand Down

0 comments on commit b10be5f

Please sign in to comment.