Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix power operator and elif functions in NESTML #950

Merged
merged 6 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __check_return_recursively(cls, type_symbol=None, stmts=None, ret_defined=Fa
stmt.get_if_stmt().get_if_clause().get_block().get_stmts(),
ret_defined)
for else_ifs in stmt.get_if_stmt().get_elif_clauses():
cls.__check_return_recursively(type_symbol, else_ifs.get_block().get_stmt(), ret_defined)
cls.__check_return_recursively(type_symbol, else_ifs.get_block().get_stmts(), ret_defined)
if stmt.get_if_stmt().has_else_clause():
cls.__check_return_recursively(type_symbol,
stmt.get_if_stmt().get_else_clause().get_block().get_stmts(),
Expand Down
59 changes: 45 additions & 14 deletions pynestml/visitors/ast_power_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
"""
rhs : <assoc=right> left=rhs powOp='**' right=rhs
"""
from pynestml.codegeneration.nest_unit_converter import NESTUnitConverter
from pynestml.meta_model.ast_expression import ASTExpression
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.symbol import SymbolKind
from pynestml.symbols.unit_type_symbol import UnitTypeSymbol
from pynestml.utils.either import Either
from pynestml.utils.error_strings import ErrorStrings
Expand Down Expand Up @@ -70,21 +73,49 @@ def calculate_numeric_value(self, expr):
:return: an Either object
:rtype: Either
"""
# TODO write tests for this by PTraeder
if isinstance(expr, ASTExpression) and expr.is_encapsulated:
return self.calculate_numeric_value(expr.get_expression())
elif isinstance(expr, ASTSimpleExpression) and expr.get_numeric_literal() is not None:
if isinstance(expr.get_numeric_literal(), int) \
or isinstance(expr.get_numeric_literal(), float):
literal = expr.get_numeric_literal()
return Either.value(literal)
else:
if isinstance(expr, ASTExpression):
if expr.is_encapsulated:
return self.calculate_numeric_value(expr.get_expression())
if expr.is_unary_operator() and expr.get_unary_operator().is_unary_minus:
term = self.calculate_numeric_value(expr.get_expression())
if term.is_error():
return term
return Either.value(-term.get_value())
if expr.get_binary_operator() is not None:
op = expr.get_binary_operator()
lhs = expr.get_lhs()
rhs = expr.get_rhs()
if op.is_plus_op:
return Either.value(self.calculate_numeric_value(lhs).get_value() + self.calculate_numeric_value(rhs).get_value())

if op.is_minus_op:
return Either.value(self.calculate_numeric_value(lhs).get_value() - self.calculate_numeric_value(rhs).get_value())

if op.is_times_op:
return Either.value(self.calculate_numeric_value(lhs).get_value() * self.calculate_numeric_value(rhs).get_value())

if op.is_div_op:
return Either.value(self.calculate_numeric_value(lhs).get_value() / self.calculate_numeric_value(rhs).get_value())

if op.is_modulo_op:
return Either.value(self.calculate_numeric_value(lhs).get_value() % self.calculate_numeric_value(rhs).get_value())
return self.calculate_numeric_value(expr)
if isinstance(expr, ASTSimpleExpression):
if expr.get_numeric_literal() is not None:
if isinstance(expr.get_numeric_literal(), int) \
or isinstance(expr.get_numeric_literal(), float):
literal = expr.get_numeric_literal()
return Either.value(literal)
error_message = ErrorStrings.message_unit_base(self, expr.get_source_position())
return Either.error(error_message)
elif expr.is_unary_operator() and expr.get_unary_operator().is_unary_minus:
term = self.calculate_numeric_value(expr.get_expression())
if term.is_error():
return term
return Either.value(-term.get_value())

# expr is a variable
variable = expr.get_variable()
symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)
if symbol is None:
if PredefinedUnits.is_unit(variable.get_complete_name()):
return Either.value(NESTUnitConverter.get_factor(PredefinedUnits.get_unit(variable.get_complete_name()).get_unit()))
return self.calculate_numeric_value(symbol.get_declaring_expression())

error_message = ErrorStrings.message_non_constant_exponent(self, expr.get_source_position())
return Either.error(error_message)
2 changes: 1 addition & 1 deletion tests/cocos_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_invalid_function_unique_and_defined(self):
os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
'CoCoFunctionNotUnique.nestml'))
self.assertEqual(
len(Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], LoggingLevel.ERROR)), 4)
len(Logger.get_all_messages_of_level_and_or_node(model.get_neuron_list()[0], LoggingLevel.ERROR)), 5)

def test_valid_function_unique_and_defined(self):
Logger.set_logging_level(LoggingLevel.INFO)
Expand Down
2 changes: 2 additions & 0 deletions tests/invalid/CoCoFunctionNotUnique.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,7 @@ neuron CoCoFunctionNotUnique:
test real = 1
if True == True:
return True
elif Tau_a == 2:
test = Tau_a # here no return statement should be detected
else:
test = test # here no return statement should be detected
11 changes: 11 additions & 0 deletions tests/resources/ExpressionCollection.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ neuron ExpressionCollection:
g_GABAA' nS/ms = 0.
g_GABAB' nS/ms = 0.

# power operator
bar real = 1.5

parameters:
#neuron aeif_cond_alpha_neuron
testA ms**2 = 1
Expand All @@ -84,6 +87,10 @@ neuron ExpressionCollection:

beta real = 1. # check for conflict with sympy built-in functions like beta()

# for power operator
expo1 integer = 3
expo2 mmol/pA = 2 mmol/pA

#hh_cond_exp_traub_neuron
test70 nS = 0nS # Inhibitory synaptic conductance

Expand Down Expand Up @@ -481,6 +488,10 @@ neuron ExpressionCollection:
test78 real = alpha_h_init / ( alpha_h_init + beta_h_init )
test79 real = alpha_n_init / ( alpha_n_init + beta_n_init )

bar = beta ** expo1
bar = beta ** expo2
bar = beta ** (expo1 + expo2)

integrate_odes()
# sending spikes: crossing 0 mV, pseudo-refractoriness and local maximum...
if r > 0: # is refractory?
Expand Down
15 changes: 15 additions & 0 deletions tests/resources/ExpressionTypeTest.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ neuron expressionType_test:
timesquared s**2 = 1.44 s**2
velocity m/s = 23.5 m/s

# power operator
bar1 mol**2 = 1.5 mol**2
bar2 mol**3 = -22 mol**3
bar3 mol**5 = 3.99 mol**5
beta mol = 42 mol

parameters:
expo1 integer = 3
expo2 mmol/pA = 2 mmol/pA

function foo(_mass kg, _dist m, _time s) N :
return _mass*_dist/(_time**2)

Expand All @@ -72,3 +82,8 @@ neuron expressionType_test:
force = foo(mass,distance,time)
force = 1 kg * distance/(time**2)
force = kg*m/(s**2)

# power operators
bar1 = beta ** expo2
bar2 = beta ** expo1
bar3 = beta ** (expo1 + expo2)
9 changes: 6 additions & 3 deletions tests/valid/CoCoFunctionNotUnique.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ neuron CoCoFunctionNotUnique:

function deltaNoReturn(Tau_a ms,Tau_b ms) boolean:
test real = 1
if True == True:
return True
if Tau_a == 1:
test = True
elif Tau_b == 2:
test = True
else:
return False
test = False
return test
Loading