Skip to content

Commit

Permalink
Fix print function (#951)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu authored Sep 13, 2023
1 parent f8dbab7 commit 13d08a4
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
13 changes: 12 additions & 1 deletion pynestml/codegeneration/printers/cpp_function_call_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import re

from pynestml.symbols.symbol import SymbolKind

from pynestml.codegeneration.printers.function_call_printer import FunctionCallPrinter
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.symbol_table.scope import Scope
Expand Down Expand Up @@ -191,7 +193,16 @@ def __convert_print_statement_str(self, stmt: str, scope: Scope) -> str:
fun_left = (lambda lhs: self.__convert_print_statement_str(lhs, scope) + ' << ' if lhs else '')
fun_right = (lambda rhs: ' << ' + self.__convert_print_statement_str(rhs, scope) if rhs else '')
ast_var = ASTVariable(var_name, scope=scope)
right = ' ' + ASTUtils.get_unit_name(ast_var) + right # concatenate unit separated by a space with the right part of the string

# set the `_is_numeric` value for the variable so that the variable is printed with the correct origin
symbol = ast_var.get_scope().resolve_to_symbol(var_name, SymbolKind.VARIABLE)
if symbol:
if "_is_numeric" in dir(symbol):
ast_var._is_numeric = symbol._is_numeric

# concatenate unit separated by a space with the right part of the string
if ASTUtils.get_unit_name(ast_var):
right = ' ' + ASTUtils.get_unit_name(ast_var) + right
return fun_left(left) + self._expression_printer.print(ast_var) + fun_right(right)

return '"' + stmt + '"' # format bare string in C++ (add double quotes)
5 changes: 5 additions & 0 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2176,6 +2176,11 @@ def visit_variable(self, node):
else:
node._is_numeric = False

# Set the `_is_numeric` flag in its corresponding symbol
symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE)
if symbol:
symbol._is_numeric = node._is_numeric

visitor = ASTVariableOriginSetterVisitor()
visitor._numeric_state_variables = numeric_state_variable_names
neuron.accept(visitor)
Expand Down
1 change: 1 addition & 0 deletions pynestml/visitors/ast_equations_with_delay_vars_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def visit_simple_expression(self, node: ASTSimpleExpression):
# Get the delay parameter
delay_parameter = ASTUtils.extract_delay_parameter(node.get_function_call())
ast_variable.set_delay_parameter(delay_parameter)
ast_variable.update_scope(node.get_scope())

# Set the variable in the SimpleExpression node
node.set_variable(ast_variable)
Expand Down
1 change: 1 addition & 0 deletions tests/nest_tests/print_statement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_print_statement(self):
matches = [s for s in lines if "print:" in s]
self.assertEqual(matches[0], "print: This is a simple print statement\n")
self.assertEqual(matches[1], "print: Membrane voltage: -0.05 V, threshold: -7e-08 MA Ohm, and V_rel: -50 mV\n")
self.assertEqual(matches[2], "print: Numeric state variable: 0.048731\n")

def tearDown(self) -> None:
if os.path.exists(self.output_path):
Expand Down
15 changes: 14 additions & 1 deletion tests/nest_tests/resources/PrintVariables.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,21 @@ neuron print_variable:
V_m V = -50 mV
V_thr MA*Ohm = -70 mV
V_rel mV = 0 mV
x real = 0

equations:
x' = alpha_x / ((1 + x**n_x) * ms) - beta_x * x

parameters:
alpha_x real = 0.5
beta_x real = 0.5
n_x integer = 2

update:
integrate_odes()

V_rel = V_m
println("print: This is a simple print statement")
print("print: Membrane voltage: {V_m}, threshold: {V_thr}, and V_rel: {V_rel}")
println("print: Membrane voltage: {V_m}, threshold: {V_thr}, and V_rel: {V_rel}")
print("print: Numeric state variable: {x}")

0 comments on commit 13d08a4

Please sign in to comment.