Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert a test to use Hypothesis #1507

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy>=1.23.3" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy filelock etuples logical-unification miniKanren cons typing_extensions "setuptools>=48.0.0"
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy>=1.23.3" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy filelock etuples logical-unification miniKanren cons typing_extensions hypothesis "setuptools>=48.0.0"
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge -c numba "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57.0"; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy>=1.23.3" jax jaxlib
pip install --no-deps -e ./
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.hypothesis
*.pkl
_build
__pycache__
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dependencies = [
"cons",
"typing_extensions",
"setuptools >=48.0.0",
"hypothesis",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need to add this to the test set up script (i.e. here).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it. I'm not sure how to test it, as I'm a little fuzzy on what this script is used for, but I guess we'll find out when something happens on GitHub. :^)

]
dynamic = ["version"]

Expand Down
101 changes: 101 additions & 0 deletions tests/tensor/test_numpy_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from sys import float_info

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import given, settings

import aesara
import tests.unittest_tools as utt
from aesara.tensor.type import dscalar, zscalar


pytestmark = pytest.mark.filterwarnings("error")

# functions that accept either real or complex inputs
# the second value in each tupple converts an arbitrary float value into a
# value in the function's domain.
COMPLEX_FUNCTIONS = [
(np.tanh, np.log),
(np.cosh, np.log),
(np.sinh, np.log),
(np.arcsinh, lambda x: x),
(np.log, lambda x: x),
(np.log10, lambda x: x),
(np.log1p, lambda x: x / 2),
(np.log2, lambda x: x),
(np.exp, np.log),
(np.expm1, np.log),
(np.exp2, lambda x: np.log2(x) - 1),
(np.sqrt, abs),
]

# functions that only accept real inputs
# the second value in each tupple converts an arbitrary float value into a
# value in the function's domain
REAL_FUNCTIONS = [
(np.deg2rad, lambda x: x),
(np.rad2deg, np.deg2rad),
(np.cos, lambda x: x),
(np.sin, lambda x: x),
(np.tan, lambda x: x),
(np.arctan, lambda x: x),
(np.arcsin, lambda x: x / float_info.max),
(np.arccos, lambda x: x / float_info.max),
]


# tests calling a function with a real value
def do_real_test(fct, value: float):
# set up
x = dscalar("x")
y = fct(x)
f = aesara.function([x], y)

# exercise and verify
utt.assert_allclose(fct(value), f(value))


# tests functions that can be invoked with either real or imaginary inputs
# with real inputs
@pytest.mark.parametrize(
"fct, to_domain",
REAL_FUNCTIONS + COMPLEX_FUNCTIONS,
)
@given(value=st.floats(float_info.min, float_info.max))
@settings(deadline=None)
def test_real(fct, to_domain, value):
do_real_test(fct, to_domain(value))


# arccosh has a domain that is awkward to derive from min/max float
@given(value=st.floats(1, float_info.max))
@settings(deadline=None)
def test_arccosh(value):
do_real_test(np.arccosh, value)


# arctanh has a domain that is awkward to derive from min/max float
@given(value=st.floats(-1 + float_info.epsilon, 1 - float_info.epsilon))
@settings(deadline=None)
def test_arctanh_real(value):
do_real_test(np.arctanh, value)


# tests functions with complex inputs
# cscalar doesn't work for because it loses precision
@pytest.mark.parametrize("fct, to_domain", COMPLEX_FUNCTIONS)
@given(
real_value=st.floats(float_info.min, float_info.max),
imaginary_value=st.floats(float_info.min, float_info.max),
)
@settings(deadline=None)
def test_complex(fct, to_domain, real_value, imaginary_value):
# set up
x = zscalar("x")
y = fct(x)
f = aesara.function([x], y)
value = to_domain(real_value) + 1j * to_domain(imaginary_value)

# exercise and verify
utt.assert_allclose(fct(value), f(value))
36 changes: 0 additions & 36 deletions tests/tensor/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from numpy.testing import assert_array_equal, assert_equal, assert_string_equal

import aesara
import tests.unittest_tools as utt
from aesara.compile.mode import get_default_mode
from aesara.graph.basic import Constant, equal_computations
from aesara.tensor import get_vector_length
Expand All @@ -17,7 +16,6 @@
TensorType,
cscalar,
dmatrix,
dscalar,
dvector,
iscalar,
ivector,
Expand All @@ -39,40 +37,6 @@
pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize(
"fct, value",
[
(np.arccos, 0.5),
(np.arccosh, 1.0),
(np.arcsin, 0.5),
(np.arcsinh, 0.5),
(np.arctan, 0.5),
(np.arctanh, 0.5),
(np.cos, 0.5),
(np.cosh, 0.5),
(np.deg2rad, 0.5),
(np.exp, 0.5),
(np.exp2, 0.5),
(np.expm1, 0.5),
(np.log, 0.5),
(np.log10, 0.5),
(np.log1p, 0.5),
(np.log2, 0.5),
(np.rad2deg, 0.5),
(np.sin, 0.5),
(np.sinh, 0.5),
(np.sqrt, 0.5),
(np.tan, 0.5),
(np.tanh, 0.5),
],
)
def test_numpy_method(fct, value):
x = dscalar("x")
y = fct(x)
f = aesara.function([x], y)
utt.assert_allclose(np.nan_to_num(f(value)), np.nan_to_num(fct(value)))


def test_infix_dot_method():
X = dmatrix("X")
y = dvector("y")
Expand Down