From 13d08a4a10b95ed4218c7aa03545362c6e20d03d Mon Sep 17 00:00:00 2001 From: Pooja Babu <75320801+pnbabu@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:35:21 +0200 Subject: [PATCH] Fix print function (#951) --- .../printers/cpp_function_call_printer.py | 13 ++++++++++++- pynestml/utils/ast_utils.py | 5 +++++ .../ast_equations_with_delay_vars_visitor.py | 1 + tests/nest_tests/print_statement_test.py | 1 + tests/nest_tests/resources/PrintVariables.nestml | 15 ++++++++++++++- 5 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pynestml/codegeneration/printers/cpp_function_call_printer.py b/pynestml/codegeneration/printers/cpp_function_call_printer.py index 126c60529..943c7518f 100644 --- a/pynestml/codegeneration/printers/cpp_function_call_printer.py +++ b/pynestml/codegeneration/printers/cpp_function_call_printer.py @@ -23,6 +23,8 @@ import re +from pynestml.symbols.symbol import SymbolKind + from pynestml.codegeneration.printers.function_call_printer import FunctionCallPrinter from pynestml.meta_model.ast_function_call import ASTFunctionCall from pynestml.symbol_table.scope import Scope @@ -191,7 +193,16 @@ def __convert_print_statement_str(self, stmt: str, scope: Scope) -> str: fun_left = (lambda lhs: self.__convert_print_statement_str(lhs, scope) + ' << ' if lhs else '') fun_right = (lambda rhs: ' << ' + self.__convert_print_statement_str(rhs, scope) if rhs else '') ast_var = ASTVariable(var_name, scope=scope) - right = ' ' + ASTUtils.get_unit_name(ast_var) + right # concatenate unit separated by a space with the right part of the string + + # set the `_is_numeric` value for the variable so that the variable is printed with the correct origin + symbol = ast_var.get_scope().resolve_to_symbol(var_name, SymbolKind.VARIABLE) + if symbol: + if "_is_numeric" in dir(symbol): + ast_var._is_numeric = symbol._is_numeric + + # concatenate unit separated by a space with the right part of the string + if ASTUtils.get_unit_name(ast_var): + right = ' ' + ASTUtils.get_unit_name(ast_var) + right return fun_left(left) + self._expression_printer.print(ast_var) + fun_right(right) return '"' + stmt + '"' # format bare string in C++ (add double quotes) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 9e4718a1c..0070a5832 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -2176,6 +2176,11 @@ def visit_variable(self, node): else: node._is_numeric = False + # Set the `_is_numeric` flag in its corresponding symbol + symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE) + if symbol: + symbol._is_numeric = node._is_numeric + visitor = ASTVariableOriginSetterVisitor() visitor._numeric_state_variables = numeric_state_variable_names neuron.accept(visitor) diff --git a/pynestml/visitors/ast_equations_with_delay_vars_visitor.py b/pynestml/visitors/ast_equations_with_delay_vars_visitor.py index 0d38e3cba..1977271d7 100644 --- a/pynestml/visitors/ast_equations_with_delay_vars_visitor.py +++ b/pynestml/visitors/ast_equations_with_delay_vars_visitor.py @@ -45,6 +45,7 @@ def visit_simple_expression(self, node: ASTSimpleExpression): # Get the delay parameter delay_parameter = ASTUtils.extract_delay_parameter(node.get_function_call()) ast_variable.set_delay_parameter(delay_parameter) + ast_variable.update_scope(node.get_scope()) # Set the variable in the SimpleExpression node node.set_variable(ast_variable) diff --git a/tests/nest_tests/print_statement_test.py b/tests/nest_tests/print_statement_test.py index 75484300b..d4ebee568 100644 --- a/tests/nest_tests/print_statement_test.py +++ b/tests/nest_tests/print_statement_test.py @@ -43,6 +43,7 @@ def test_print_statement(self): matches = [s for s in lines if "print:" in s] self.assertEqual(matches[0], "print: This is a simple print statement\n") self.assertEqual(matches[1], "print: Membrane voltage: -0.05 V, threshold: -7e-08 MA Ohm, and V_rel: -50 mV\n") + self.assertEqual(matches[2], "print: Numeric state variable: 0.048731\n") def tearDown(self) -> None: if os.path.exists(self.output_path): diff --git a/tests/nest_tests/resources/PrintVariables.nestml b/tests/nest_tests/resources/PrintVariables.nestml index e3b25f7ea..31144f6f4 100644 --- a/tests/nest_tests/resources/PrintVariables.nestml +++ b/tests/nest_tests/resources/PrintVariables.nestml @@ -28,8 +28,21 @@ neuron print_variable: V_m V = -50 mV V_thr MA*Ohm = -70 mV V_rel mV = 0 mV + x real = 0 + + equations: + x' = alpha_x / ((1 + x**n_x) * ms) - beta_x * x + + parameters: + alpha_x real = 0.5 + beta_x real = 0.5 + n_x integer = 2 update: + integrate_odes() + V_rel = V_m println("print: This is a simple print statement") - print("print: Membrane voltage: {V_m}, threshold: {V_thr}, and V_rel: {V_rel}") + println("print: Membrane voltage: {V_m}, threshold: {V_thr}, and V_rel: {V_rel}") + print("print: Numeric state variable: {x}") +