Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle input containing division by zero better #77

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions odetoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8

from .config import Config
from .sympy_helpers import _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _find_in_matrix, _is_zero, _is_sympy_type, SympyPrinter
from .system_of_shapes import SystemOfShapes
from .shapes import MalformedInputException, Shape

Expand Down Expand Up @@ -109,7 +109,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
"""

logging.info("Processing input shapes...")
shapes = []

# first run for grabbing all the variable names. Coefficients might be incorrect.
all_variable_symbols = []
all_parameter_symbols = set()
Expand All @@ -124,6 +124,11 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
assert all([_is_sympy_type(sym) for sym in all_variable_symbols])
logging.info("All known variables: " + str(all_variable_symbols) + ", all parameters used in ODEs: " + str(all_parameter_symbols))

# validate input for forbidden names
for var in set(all_variable_symbols) | all_parameter_symbols:
_check_forbidden_name(var)

# validate parameters
for param in all_parameter_symbols:
if parameters is None:
parameters = dict()
Expand All @@ -135,6 +140,7 @@ def _from_json_to_shapes(indict, parameters=None) -> Tuple[List[Shape], Dict[sym
parameters[param] = None

# second run with the now-known list of variable symbols
shapes = []
for shape_json in indict["dynamics"]:
shape = Shape.from_json(shape_json, all_variable_symbols=all_variable_symbols, parameters=parameters, _debug=True)
shapes.append(shape)
Expand Down Expand Up @@ -210,6 +216,9 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
assert type(k) is sympy.Symbol
parameters[k] = v

_check_forbidden_name(k)


#
# create Shapes and SystemOfShapes
#
Expand Down Expand Up @@ -292,8 +301,12 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
all_shape_symbols = [str(sympy.Symbol(str(shape.symbol) + Config().differential_order_symbol * i)) for i in range(shape.order)]
for sym in all_shape_symbols:
if sym in solver_json["state_variables"]:
solver_json["initial_values"][sym] = str(shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'")))
iv_expr = shape.get_initial_value(sym.replace(Config().differential_order_symbol, "'"))
solver_json["initial_values"][sym] = str(iv_expr)

# validate output for numerical problems
for var in iv_expr.atoms():
_check_numerical_issue(var)

#
# copy the parameter values from the input to the output for convenience; convert into numeric values
Expand All @@ -318,7 +331,20 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
break

if symbol_appears_in_any_expr:
solver_json["parameters"][param_name] = str(sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals).n())
sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals)

# validate output for numerical problems
for var in sympy_expr.atoms():
_check_numerical_issue(var)

# convert to numeric value
sympy_expr = sympy_expr.n()

# validate output for numerical problems
for var in sympy_expr.atoms():
_check_numerical_issue(var)

solver_json["parameters"][param_name] = str(sympy_expr)


#
Expand Down
3 changes: 2 additions & 1 deletion odetoolbox/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class Config:
"sim_time": 100E-3,
"max_step_size": 999.,
"integration_accuracy_abs": 1E-6,
"integration_accuracy_rel": 1E-6
"integration_accuracy_rel": 1E-6,
"forbidden_names": ["oo", "zoo", "nan", "NaN", "__h"]
}

def __getitem__(self, key):
Expand Down
38 changes: 21 additions & 17 deletions odetoolbox/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import annotations

from typing import List, Mapping, Tuple
from typing import List, Tuple

import functools
import logging
Expand All @@ -30,11 +30,9 @@
import sympy.parsing.sympy_parser

from sympy.core.expr import Expr as SympyExpr # works for both sympy 1.4 and 1.8
from sympy.core.numbers import One as SympyOne
from sympy.core.numbers import Zero as SympyZero

from .config import Config
from .sympy_helpers import _custom_simplify_expr, _is_sympy_type, _is_zero
from .sympy_helpers import _check_numerical_issue, _check_forbidden_name, _custom_simplify_expr, _is_constant_term, _is_sympy_type, _is_zero


class MalformedInputException(Exception):
Expand All @@ -44,17 +42,6 @@ class MalformedInputException(Exception):
pass


def is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None):
r"""
:return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
"""
if parameters is None:
parameters = {}
assert all([type(k) is sympy.Symbol for k in parameters.keys()])
return type(term) in [sympy.Float, sympy.Integer, SympyZero, SympyOne] \
or all([sym in parameters.keys() for sym in term.free_symbols])


class Shape:
r"""
This class provides a canonical representation of a shape function independently of the way in which the user specified the shape. It assumes a differential equation of the general form (where bracketed superscript :math:`{}^{(n)}` indicates the :math:`n`-th derivative with respect to time):
Expand All @@ -79,6 +66,7 @@ class Shape:
"Integer": sympy.Integer,
"Float": sympy.Float,
"Function": sympy.Function,
"Mul": sympy.Mul,
"Pow": sympy.Pow,
"power": sympy.Pow,
"exp": sympy.exp,
Expand Down Expand Up @@ -326,6 +314,9 @@ def from_json(cls, indict, all_variable_symbols=None, parameters=None, _debug=Fa
raise MalformedInputException("In defintion of initial value for variable \"" + iv_symbol + "\": differential order (" + str(iv_order) + ") exceeds that of overall equation order (" + str(order) + ")")
if initial_val_specified[iv_order]:
raise MalformedInputException("Initial value for order " + str(iv_order) + " specified more than once")

_check_forbidden_name(iv_symbol)

initial_val_specified[iv_order] = True
initial_values[iv_symbol + iv_order * "'"] = iv_rhs

Expand Down Expand Up @@ -392,13 +383,13 @@ def split_lin_inhom_nonlin(expr, x, parameters=None):
terms = [expr]

for term in terms:
if is_constant_term(term, parameters=parameters):
if _is_constant_term(term, parameters=parameters):
inhom_term += term
else:
# check if the term is linear in any of the symbols in `x`
is_lin = False
for j, sym in enumerate(x):
if is_constant_term(term / sym, parameters=parameters):
if _is_constant_term(term / sym, parameters=parameters):
lin_factors[j] += term / sym
is_lin = True
break
Expand Down Expand Up @@ -584,8 +575,21 @@ def from_ode(cls, symbol: str, definition: str, initial_values: dict, all_variab
order: int = len(initial_values)
all_variable_symbols_dict = {str(el): el for el in all_variable_symbols}
definition = sympy.parsing.sympy_parser.parse_expr(definition.replace("'", Config().differential_order_symbol), global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) # minimal global_dict to make no assumptions (e.g. "beta" could otherwise be recognised as a function instead of as a parameter symbol)

# validate input for forbidden names
_initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, evaluate=False, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()}
for iv_expr in _initial_values.values():
for var in iv_expr.atoms():
_check_forbidden_name(var)

# parse input
initial_values = {k: sympy.parsing.sympy_parser.parse_expr(v, global_dict=Shape._sympy_globals, local_dict=all_variable_symbols_dict) for k, v in initial_values.items()}

# validate input for numerical issues
for iv_expr in initial_values.values():
for var in iv_expr.atoms():
_check_numerical_issue(var)

local_symbols = [symbol + Config().differential_order_symbol * i for i in range(order)]
local_symbols_sympy = [sympy.Symbol(sym_name) for sym_name in local_symbols]
if not symbol in all_variable_symbols:
Expand Down
35 changes: 34 additions & 1 deletion odetoolbox/sympy_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# sympy_printer.py
# sympy_helpers.py
#
# This file is part of the NEST ODE toolbox.
#
Expand All @@ -19,13 +19,46 @@
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

from typing import Mapping

import logging
import sympy
import sys

from .config import Config


class NumericalIssueException(Exception):
r"""Thrown in case of numerical issues, like division by zero."""
pass


def _is_constant_term(term, parameters: Mapping[sympy.Symbol, str] = None) -> bool:
r"""
:return: :python:`True` if and only if this term contains only numerical values and parameters; :python:`False` otherwise.
"""
if parameters is None:
parameters = {}
assert all([type(k) is sympy.Symbol for k in parameters.keys()])
return type(term) in [sympy.Float, sympy.Integer, sympy.core.numbers.Zero, sympy.core.numbers.One] \
or all([sym in parameters.keys() for sym in term.free_symbols])


def _check_numerical_issue(var: str) -> None:
forbidden_vars = ["zoo", "oo", "nan", "NaN"]
stripped_var_name = str(var).strip("'")
if stripped_var_name in forbidden_vars:
raise NumericalIssueException("The variable \"" + stripped_var_name + "\" was found. This indicates a numerical problem while solving the system of ODEs. Please check the input for correctness (such as the presence of divisions by zero).")


def _check_forbidden_name(var: str) -> None:
from .shapes import MalformedInputException

stripped_var_name = str(var).strip("'")
if stripped_var_name in Config().forbidden_names + dir(sympy.core.numbers):
raise MalformedInputException("Variable by name \"" + stripped_var_name + "\" not allowed; this is a reserved name.")


def _is_zero(x):
r"""
Check if a sympy expression is equal to zero.
Expand Down
66 changes: 66 additions & 0 deletions tests/test_malformed_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#
# test_malformed_input.py
#
# This file is part of the NEST ODE toolbox.
#
# Copyright (C) 2017 The NEST Initiative
#
# The NEST ODE toolbox is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation, either version 2 of
# the License, or (at your option) any later version.
#
# The NEST ODE toolbox is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

import pytest

from .context import odetoolbox
from odetoolbox.shapes import MalformedInputException
from odetoolbox.sympy_helpers import NumericalIssueException


class TestMalformedInput:
r"""Test for failure when forbidden names are used in the input."""

@pytest.mark.xfail(strict=True, raises=MalformedInputException)
def test_malformed_input_iv(self):
indict = {"dynamics": [{"expression": "x' = 0",
"initial_value": "zoo"}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=MalformedInputException)
def test_malformed_input_expr(self):
indict = {"dynamics": [{"expression": "x' = 42 * NaN",
"initial_value": "0."}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=MalformedInputException)
def test_malformed_input_sym(self):
indict = {"dynamics": [{"expression": "oo' = 0",
"initial_value": "0."}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

def test_correct_input(self):
indict = {"dynamics": [{"expression": "foo' = 0",
"initial_value": "0."}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=NumericalIssueException)
def test_malformed_input_numerical_iv(self):
indict = {"dynamics": [{"expression": "foo' = 0",
"initial_value": "1/0"}]}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)

@pytest.mark.xfail(strict=True, raises=NumericalIssueException)
def test_malformed_input_numerical_parameter(self):
indict = {"dynamics": [{"expression": "foo' = bar",
"initial_value": "1"}],
"parameters": {"bar": "1/0"}}
result = odetoolbox.analysis(indict, disable_stiffness_check=True)
Loading