Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-101410: support custom messages for domain errors in the math module #124299

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
40 changes: 40 additions & 0 deletions Lib/test/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2503,6 +2503,46 @@ def test_input_exceptions(self):
self.assertRaises(TypeError, math.atan2, 1.0)
self.assertRaises(TypeError, math.atan2, 1.0, 2.0, 3.0)

def test_exception_messages(self):
x = -1.1
with self.assertRaisesRegex(ValueError,
f"expected a nonnegative input, got {x}"):
math.sqrt(x)
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {x}"):
math.log(x)
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {x}"):
math.log(123, x)
skirpichev marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {x}"):
math.log(x, 123)
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {x}"):
math.log2(x)
skirpichev marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {x}"):
math.log10(x)
x = decimal.Decimal('-1.1')
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {x}"):
math.log(x)
x = fractions.Fraction(1, 10**400)
with self.assertRaisesRegex(ValueError,
f"expected a positive input, got {float(x)}"):
math.log(x)
x = -123
with self.assertRaisesRegex(ValueError,
f"expected a positive input"):
math.log(x)
with self.assertRaisesRegex(ValueError,
f"expected a float or nonnegative integer, got {x}"):
math.gamma(x)
x = 1.0
with self.assertRaisesRegex(ValueError,
f"expected a number between -1 and 1, got {x}"):
math.atanh(x)

# Custom assertions.

def assertIsNaN(self, value):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Support custom messages for domain errors in the :mod:`math` module
(:func:`math.sqrt`, :func:`math.log` and :func:`math.atanh` were modified as
examples). Patch by Charlie Zhao and Sergey B Kirpichev.
93 changes: 65 additions & 28 deletions Modules/mathmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -851,12 +851,15 @@ PyDoc_STRVAR(math_lcm_doc,
* true (1), but may return false (0) without setting up an exception.
*/
static int
is_error(double x)
is_error(double x, int raise_edom)
{
int result = 1; /* presumption of guilt */
assert(errno); /* non-zero errno is a precondition for calling */
if (errno == EDOM)
PyErr_SetString(PyExc_ValueError, "math domain error");
if (errno == EDOM) {
if (raise_edom) {
PyErr_SetString(PyExc_ValueError, "math domain error");
}
}

else if (errno == ERANGE) {
/* ANSI C generally requires libm functions to set ERANGE
Expand Down Expand Up @@ -921,50 +924,69 @@ is_error(double x)
*/

static PyObject *
math_1(PyObject *arg, double (*func) (double), int can_overflow)
math_1(PyObject *arg, double (*func) (double), int can_overflow,
const char *err_msg)
{
double x, r;
x = PyFloat_AsDouble(arg);
if (x == -1.0 && PyErr_Occurred())
return NULL;
errno = 0;
r = (*func)(x);
if (isnan(r) && !isnan(x)) {
PyErr_SetString(PyExc_ValueError,
"math domain error"); /* invalid arg */
return NULL;
}
if (isnan(r) && !isnan(x))
goto domain_err; /* domain error */
if (isinf(r) && isfinite(x)) {
if (can_overflow)
PyErr_SetString(PyExc_OverflowError,
"math range error"); /* overflow */
else
PyErr_SetString(PyExc_ValueError,
"math domain error"); /* singularity */
goto domain_err; /* singularity */
return NULL;
}
if (isfinite(r) && errno && is_error(r))
if (isfinite(r) && errno && is_error(r, 1))
/* this branch unnecessary on most platforms */
return NULL;

return PyFloat_FromDouble(r);

domain_err:
if (err_msg) {
char *buf = PyOS_double_to_string(x, 'r', 0, Py_DTSF_ADD_DOT_0, NULL);
if (buf) {
PyErr_Format(PyExc_ValueError, err_msg, buf);
PyMem_Free(buf);
}
}
else {
PyErr_SetString(PyExc_ValueError, "math domain error");
}
return NULL;
}

/* variant of math_1, to be used when the function being wrapped is known to
set errno properly (that is, errno = EDOM for invalid or divide-by-zero,
errno = ERANGE for overflow). */

static PyObject *
math_1a(PyObject *arg, double (*func) (double))
math_1a(PyObject *arg, double (*func) (double), const char *err_msg)
{
double x, r;
x = PyFloat_AsDouble(arg);
if (x == -1.0 && PyErr_Occurred())
return NULL;
errno = 0;
r = (*func)(x);
if (errno && is_error(r))
if (errno && is_error(r, err_msg ? 0 : 1)) {
if (err_msg && errno == EDOM) {
skirpichev marked this conversation as resolved.
Show resolved Hide resolved
skirpichev marked this conversation as resolved.
Show resolved Hide resolved
assert(!PyErr_Occurred()); /* exception is not set by is_error() */
char *buf = PyOS_double_to_string(x, 'r', 0, Py_DTSF_ADD_DOT_0, NULL);
if (buf) {
PyErr_Format(PyExc_ValueError, err_msg, buf);
PyMem_Free(buf);
}
}
return NULL;
}
return PyFloat_FromDouble(r);
}

Expand Down Expand Up @@ -1024,21 +1046,33 @@ math_2(PyObject *const *args, Py_ssize_t nargs,
else
errno = 0;
}
if (errno && is_error(r))
if (errno && is_error(r, 1))
return NULL;
else
return PyFloat_FromDouble(r);
}

#define FUNC1(funcname, func, can_overflow, docstring) \
static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
return math_1(args, func, can_overflow); \
return math_1(args, func, can_overflow, NULL); \
}\
skirpichev marked this conversation as resolved.
Show resolved Hide resolved
PyDoc_STRVAR(math_##funcname##_doc, docstring);
skirpichev marked this conversation as resolved.
Show resolved Hide resolved

#define FUNC1D(funcname, func, can_overflow, docstring, err_msg) \
skirpichev marked this conversation as resolved.
Show resolved Hide resolved
static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
return math_1(args, func, can_overflow, err_msg); \
}\
PyDoc_STRVAR(math_##funcname##_doc, docstring);

#define FUNC1A(funcname, func, docstring) \
static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
return math_1a(args, func); \
return math_1a(args, func, NULL); \
}\
PyDoc_STRVAR(math_##funcname##_doc, docstring);

#define FUNC1AD(funcname, func, docstring, err_msg) \
static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
return math_1a(args, func, err_msg); \
}\
PyDoc_STRVAR(math_##funcname##_doc, docstring);

Expand Down Expand Up @@ -1070,9 +1104,10 @@ FUNC2(atan2, atan2,
"atan2($module, y, x, /)\n--\n\n"
"Return the arc tangent (measured in radians) of y/x.\n\n"
"Unlike atan(y/x), the signs of both x and y are considered.")
FUNC1(atanh, atanh, 0,
FUNC1D(atanh, atanh, 0,
"atanh($module, x, /)\n--\n\n"
"Return the inverse hyperbolic tangent of x.")
"Return the inverse hyperbolic tangent of x.",
"expected a number between -1 and 1, got %s")
FUNC1(cbrt, cbrt, 0,
"cbrt($module, x, /)\n--\n\n"
"Return the cube root of x.")
Expand Down Expand Up @@ -1183,9 +1218,10 @@ math_floor(PyObject *module, PyObject *number)
return PyLong_FromDouble(floor(x));
}

FUNC1A(gamma, m_tgamma,
FUNC1AD(gamma, m_tgamma,
"gamma($module, x, /)\n--\n\n"
"Gamma function at x.")
"Gamma function at x.",
"expected a float or nonnegative integer, got %s")
FUNC1A(lgamma, m_lgamma,
"lgamma($module, x, /)\n--\n\n"
"Natural logarithm of absolute value of Gamma function at x.")
Expand All @@ -1205,9 +1241,10 @@ FUNC1(sin, sin, 0,
FUNC1(sinh, sinh, 1,
"sinh($module, x, /)\n--\n\n"
"Return the hyperbolic sine of x.")
FUNC1(sqrt, sqrt, 0,
FUNC1D(sqrt, sqrt, 0,
"sqrt($module, x, /)\n--\n\n"
"Return the square root of x.")
"Return the square root of x.",
"expected a nonnegative input, got %s")
FUNC1(tan, tan, 0,
"tan($module, x, /)\n--\n\n"
"Return the tangent of x (measured in radians).")
Expand Down Expand Up @@ -2134,7 +2171,7 @@ math_ldexp_impl(PyObject *module, double x, PyObject *i)
errno = ERANGE;
}

if (errno && is_error(r))
if (errno && is_error(r, 1))
return NULL;
return PyFloat_FromDouble(r);
}
Expand Down Expand Up @@ -2189,7 +2226,7 @@ loghelper(PyObject* arg, double (*func)(double))
/* Negative or zero inputs give a ValueError. */
if (!_PyLong_IsPositive((PyLongObject *)arg)) {
PyErr_SetString(PyExc_ValueError,
"math domain error");
"expected a positive input");
return NULL;
}

Expand All @@ -2213,7 +2250,7 @@ loghelper(PyObject* arg, double (*func)(double))
}

/* Else let libm handle it by itself. */
return math_1(arg, func, 0);
return math_1(arg, func, 0, "expected a positive input, got %s");
}


Expand Down Expand Up @@ -2362,7 +2399,7 @@ math_fmod_impl(PyObject *module, double x, double y)
else
errno = 0;
}
if (errno && is_error(r))
if (errno && is_error(r, 1))
return NULL;
else
return PyFloat_FromDouble(r);
Expand Down Expand Up @@ -2999,7 +3036,7 @@ math_pow_impl(PyObject *module, double x, double y)
}
}

if (errno && is_error(r))
if (errno && is_error(r, 1))
return NULL;
else
return PyFloat_FromDouble(r);
Expand Down
Loading