From 9f1657d9df19f65a035d2132c7b77a78bcd40615 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Tue, 25 Jul 2023 03:25:58 -0700 Subject: [PATCH 1/2] fix type derivation of kernel buffers --- .../codegeneration/nest_code_generator.py | 6 ++--- pynestml/utils/ast_utils.py | 22 ++++++++++++++----- pynestml/utils/messages.py | 2 +- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 0aa0a48a1..eee458891 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -841,10 +841,8 @@ def get_spike_update_expressions(self, neuron: ASTNeuron, kernel_buffers, solver for kernel_var in kernel.get_variables(): for var_order in range(ASTUtils.get_kernel_var_order_from_ode_toolbox_result(kernel_var.get_name(), solver_dicts)): - kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name( - kernel_var.get_name(), spike_input_port, var_order) - expr = ASTUtils.get_initial_value_from_ode_toolbox_result( - kernel_spike_buf_name, solver_dicts) + kernel_spike_buf_name = ASTUtils.construct_kernel_X_spike_buf_name(kernel_var.get_name(), spike_input_port, var_order) + expr = ASTUtils.get_initial_value_from_ode_toolbox_result(kernel_spike_buf_name, solver_dicts) assert expr is not None, "Initial value not found for kernel " + kernel_var expr = str(expr) if expr in ["0", "0.", "0.0"]: diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 7ab268cfb..fe932afcf 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -28,6 +28,7 @@ from pynestml.codegeneration.printers.ast_printer import ASTPrinter from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter +from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter from pynestml.generated.PyNestMLLexer import PyNestMLLexer from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_block import ASTBlock @@ -984,7 +985,7 @@ def add_declarations_to_state_block(cls, neuron: ASTNeuron, variables: List, ini return neuron @classmethod - def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initial_value: str) -> ASTNeuron: + def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initial_value: str, type_str: str = "real") -> ASTNeuron: """ Adds a single declaration to an arbitrary state block of the neuron. The declared variable is of type real. :param neuron: a neuron @@ -997,7 +998,7 @@ def add_declaration_to_state_block(cls, neuron: ASTNeuron, variable: str, initia tmp = ModelParser.parse_expression(initial_value) vector_variable = ASTUtils.get_vectorized_variable(tmp, neuron.get_scope()) - declaration_string = variable + ' real' + ( + declaration_string = variable + " " + type_str + ( '[' + vector_variable.get_vector_parameter() + ']' if vector_variable is not None and vector_variable.has_vector_parameter() else '') + ' = ' + initial_value ast_declaration = ModelParser.parse_declaration(declaration_string) @@ -1604,12 +1605,13 @@ def update_initial_values_for_odes(cls, neuron: ASTNeuron, solver_dicts: List[di @classmethod def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List[dict], kernels: List[ASTKernel]) -> None: - """ + r""" Add the variables used in kernels from the ode-toolbox result dictionary as ODEs in NESTML AST """ for solver_dict in solver_dicts: if solver_dict is None: continue + for var_name in solver_dict["initial_values"].keys(): if cls.variable_in_kernels(var_name, kernels): # original initial value expressions should have been removed to make place for ode-toolbox results @@ -1622,9 +1624,16 @@ def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List for var_name, expr in solver_dict["initial_values"].items(): # overwrite is allowed because initial values might be repeated between numeric and analytic solver if cls.variable_in_kernels(var_name, kernels): - expr = "0" # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0 + spike_in_port_name = var_name.split("__X__")[1] + spike_in_port_name = spike_in_port_name.split("__d")[0] + spike_in_port = ASTUtils.get_input_port_by_name(neuron.get_input_blocks(), spike_in_port_name) + if spike_in_port: + type_str = NESTMLPrinter().print_data_type(spike_in_port.data_type) + else: + type_str = "real" + expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0 if not cls.declaration_in_state_block(neuron, var_name): - cls.add_declaration_to_state_block(neuron, var_name, expr) + cls.add_declaration_to_state_block(neuron, var_name, expr, type_str) @classmethod def transform_ode_and_kernels_to_json(cls, neuron: ASTNeuron, parameters_blocks: Sequence[ASTBlockWithVariables], @@ -1895,7 +1904,8 @@ def replace_var(_expr, replace_var_name: str, replace_with_var_name: str): for decl in equation_block.get_declarations(): if isinstance(decl, ASTInlineExpression) \ and isinstance(decl.get_expression(), ASTSimpleExpression) \ - and '__X__' in str(decl.get_expression()): + and '__X__' in str(decl.get_expression()) \ + and decl.get_expression().get_variable(): replace_with_var_name = decl.get_expression().get_variable().get_name() neuron.accept(ASTHigherOrderVisitor(lambda x: replace_var( x, decl.get_variable_name(), replace_with_var_name))) diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index 0529911d6..693fcaee7 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -1050,7 +1050,7 @@ def templated_arg_types_inconsistent(cls, function_name, failing_arg_idx, other_ """ message = 'In function \'' + function_name + '\': actual derived type of templated parameter ' + \ str(failing_arg_idx + 1) + ' is \'' + failing_arg_type_str + '\', which is inconsistent with that of parameter(s) ' + \ - ', '.join([str(_ + 1) for _ in other_args_idx]) + ', which have type \'' + other_type_str + '\'' + ', '.join([str(_ + 1) for _ in other_args_idx]) + ', which has/have type \'' + other_type_str + '\'' return MessageCode.TEMPLATED_ARG_TYPES_INCONSISTENT, message @classmethod From c245319b8d00c14bce13aed6ef425813e461f41a Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Wed, 9 Aug 2023 03:34:00 -0700 Subject: [PATCH 2/2] fix type derivation of kernel buffers --- pynestml/utils/ast_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 4e55586da..90461705f 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -1629,6 +1629,9 @@ def create_initial_values_for_kernels(cls, neuron: ASTNeuron, solver_dicts: List spike_in_port = ASTUtils.get_input_port_by_name(neuron.get_input_blocks(), spike_in_port_name) if spike_in_port: type_str = NESTMLPrinter().print_data_type(spike_in_port.data_type) + differential_order: int = len(re.findall("__d", var_name)) + if differential_order: + type_str += "*s**-" + str(differential_order) else: type_str = "real" expr = "0 " + type_str # for kernels, "initial value" returned by ode-toolbox is actually the increment value; the actual initial value is assumed to be 0