Skip to content

Commit

Permalink
Improves error messages for ClassLists (#80)
Browse files Browse the repository at this point in the history
* Improves error messages for ClassLists

* Addresses review comments
  • Loading branch information
DrPaulSharp authored Oct 2, 2024
1 parent 1dbeedc commit dfef174
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 45 deletions.
91 changes: 69 additions & 22 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import collections
import contextlib
import warnings
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from typing import Any, Union

import numpy as np
Expand Down Expand Up @@ -96,8 +96,8 @@ def __setitem__(self, index: int, item: object) -> None:

def _setitem(self, index: int, item: object) -> None:
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._check_classes(self + [item])
self._check_unique_name_fields(self + [item])
self._check_classes([item])
self._check_unique_name_fields([item])
self.data[index] = item

def __delitem__(self, index: int) -> None:
Expand All @@ -118,8 +118,8 @@ def _iadd(self, other: Sequence[object]) -> "ClassList":
other = [other]
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
self._check_classes(other)
self._check_unique_name_fields(other)
super().__iadd__(other)
return self

Expand Down Expand Up @@ -168,8 +168,8 @@ def append(self, obj: object = None, **kwargs) -> None:
if obj:
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self._check_classes([obj])
self._check_unique_name_fields([obj])
self.data.append(obj)
else:
if not hasattr(self, "_class_handle"):
Expand Down Expand Up @@ -215,8 +215,8 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
if obj:
if not hasattr(self, "_class_handle"):
self._class_handle = type(obj)
self._check_classes(self + [obj])
self._check_unique_name_fields(self + [obj])
self._check_classes([obj])
self._check_unique_name_fields([obj])
self.data.insert(index, obj)
else:
if not hasattr(self, "_class_handle"):
Expand Down Expand Up @@ -252,8 +252,8 @@ def extend(self, other: Sequence[object]) -> None:
other = [other]
if not hasattr(self, "_class_handle"):
self._class_handle = self._determine_class_handle(self + other)
self._check_classes(self + other)
self._check_unique_name_fields(self + other)
self._check_classes(other)
self._check_unique_name_fields(other)
self.data.extend(other)

def set_fields(self, index: int, **kwargs) -> None:
Expand Down Expand Up @@ -312,13 +312,14 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
"""
names = [name.lower() for name in self.get_names()]
with contextlib.suppress(KeyError):
if input_args[self.name_field].lower() in names:
name = input_args[self.name_field].lower()
if name in names:
raise ValueError(
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index {names.index(name)} of the ClassList",
)

def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
objects.
Expand All @@ -333,11 +334,49 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
Raised if the input list defines more than one object with the same value of name_field.
"""
names = [getattr(model, self.name_field).lower() for model in input_list if hasattr(model, self.name_field)]
if len(set(names)) != len(names):
raise ValueError(f"Input list contains objects with the same value of the {self.name_field} attribute")
error_list = []
try:
existing_names = [name.lower() for name in self.get_names()]
except AttributeError:
existing_names = []

new_names = [getattr(model, self.name_field).lower() for model in input_list if hasattr(model, self.name_field)]
full_names = existing_names + new_names

# There are duplicate names if this test fails
if len(set(full_names)) != len(full_names):
unique_names = [*dict.fromkeys(new_names)]

for name in unique_names:
existing_indices = [i for i, other_name in enumerate(existing_names) if other_name == name]
new_indices = [i for i, other_name in enumerate(new_names) if other_name == name]
if (len(existing_indices) + len(new_indices)) > 1:
existing_string = ""
new_string = ""
if existing_indices:
existing_list = ", ".join(str(i) for i in existing_indices[:-1])
existing_string = (
f" item{f's {existing_list} and ' if existing_list else ' '}"
f"{existing_indices[-1]} of the existing ClassList"
)
if new_indices:
new_list = ", ".join(str(i) for i in new_indices[:-1])
new_string = (
f" item{f's {new_list} and ' if new_list else ' '}" f"{new_indices[-1]} of the input list"
)
error_list.append(
f" '{name}' is shared between{existing_string}"
f"{', and' if existing_string and new_string else ''}{new_string}"
)

def _check_classes(self, input_list: Iterable[object]) -> None:
if error_list:
newline = "\n"
raise ValueError(
f"The value of the '{self.name_field}' attribute must be unique for each item in the ClassList:\n"
f"{newline.join(error for error in error_list)}"
)

def _check_classes(self, input_list: Sequence[object]) -> None:
"""Raise a ValueError if any object in a list of objects is not of the type specified by self._class_handle.
Parameters
Expand All @@ -348,11 +387,19 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
Raises
------
ValueError
Raised if the input list defines objects of different types.
Raised if the input list contains objects of any type other than that given in self._class_handle.
"""
if not (all(isinstance(element, self._class_handle) for element in input_list)):
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")
error_list = []
for i, element in enumerate(input_list):
if not isinstance(element, self._class_handle):
error_list.append(f" index {i} is of type {type(element).__name__}")
if error_list:
newline = "\n"
raise ValueError(
f"This ClassList only supports elements of type {self._class_handle.__name__}. "
f"In the input list:\n{newline.join(error for error in error_list)}\n"
)

def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, str]:
"""Return the object with the given value of the name_field attribute in the ClassList.
Expand All @@ -379,7 +426,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
@staticmethod
def _determine_class_handle(input_list: Sequence[object]):
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
element which satisfies "issubclass" for all of the other elements.
element which satisfies "issubclass" for all the other elements.
Parameters
----------
Expand Down
104 changes: 81 additions & 23 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_different_classes(self, input_list: Sequence[object]) -> None:
"""If we initialise a ClassList with an input containing multiple classes, we should raise a ValueError."""
with pytest.raises(
ValueError,
match=f"Input list contains elements of type other than '{type(input_list[0]).__name__}'",
match=f"This ClassList only supports elements of type {type(input_list[0]).__name__}. In the input list:\n"
f" index 1 is of type {type(input_list[1]).__name__}\n",
):
ClassList(input_list)

Expand All @@ -134,7 +135,9 @@ def test_identical_name_fields(self, input_list: Sequence[object], name_field: s
"""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the {name_field} attribute",
match=f"The value of the '{name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{getattr(input_list[0], name_field).lower()}'"
f" is shared between items 0 and 1 of the input list",
):
ClassList(input_list, name_field=name_field)

Expand Down Expand Up @@ -194,7 +197,12 @@ def test_setitem(two_name_class_list: ClassList, new_item: InputAttributes, expe
)
def test_setitem_same_name_field(two_name_class_list: ClassList, new_item: InputAttributes) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"):
with pytest.raises(
ValueError,
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{new_item.name.lower()}' is shared between item 1 of the existing ClassList,"
f" and item 0 of the input list",
):
two_name_class_list[0] = new_item


Expand All @@ -206,7 +214,11 @@ def test_setitem_same_name_field(two_name_class_list: ClassList, new_item: Input
)
def test_setitem_different_classes(two_name_class_list: ClassList, new_values: dict[str, Any]) -> None:
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
with pytest.raises(ValueError, match="Input list contains elements of type other than 'InputAttributes'"):
with pytest.raises(
ValueError,
match=f"This ClassList only supports elements of type {two_name_class_list._class_handle.__name__}. "
f"In the input list:\n index 0 is of type {type(new_values).__name__}\n",
):
two_name_class_list[0] = new_values


Expand Down Expand Up @@ -403,7 +415,9 @@ def test_append_object_same_name_field(two_name_class_list: ClassList, new_objec
"""If we append an object with an already-specified name_field value to a ClassList we should raise a ValueError."""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute",
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{new_object.name.lower()}' is shared between item 0 of the existing ClassList, and "
f"item 0 of the input list",
):
two_name_class_list.append(new_object)

Expand All @@ -420,7 +434,7 @@ def test_append_kwargs_same_name_field(two_name_class_list: ClassList, new_value
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index 0 of the ClassList",
):
two_name_class_list.append(**new_values)

Expand Down Expand Up @@ -526,7 +540,9 @@ def test_insert_object_same_name(two_name_class_list: ClassList, new_object: obj
"""If we insert an object with an already-specified name_field value to a ClassList we should raise a ValueError."""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute",
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n '{new_object.name.lower()}' is shared between item 0 of the existing "
f"ClassList, and item 0 of the input list",
):
two_name_class_list.insert(1, new_object)

Expand All @@ -543,7 +559,7 @@ def test_insert_kwargs_same_name(two_name_class_list: ClassList, new_values: dic
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index 0 of the ClassList",
):
two_name_class_list.insert(1, **new_values)

Expand Down Expand Up @@ -702,7 +718,7 @@ def test_set_fields_same_name_field(two_name_class_list: ClassList, new_values:
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{new_values[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"which is already specified at index 1 of the ClassList",
):
two_name_class_list.set_fields(0, **new_values)

Expand Down Expand Up @@ -767,7 +783,7 @@ def test__validate_name_field(two_name_class_list: ClassList, input_dict: dict[s
"input_dict",
[
({"name": "Alice"}),
({"name": "ALICE"}),
({"name": "BOB"}),
({"name": "alice"}),
],
)
Expand All @@ -777,18 +793,18 @@ def test__validate_name_field_not_unique(two_name_class_list: ClassList, input_d
with pytest.raises(
ValueError,
match=f"Input arguments contain the {two_name_class_list.name_field} "
f"'{input_dict[two_name_class_list.name_field]}', "
f"which is already specified in the ClassList",
f"'{input_dict[two_name_class_list.name_field]}', which is already specified at index "
f"{two_name_class_list.index(input_dict['name'].lower())} of the ClassList",
):
two_name_class_list._validate_name_field(input_dict)


@pytest.mark.parametrize(
"input_list",
[
([InputAttributes(name="Alice"), InputAttributes(name="Bob")]),
([InputAttributes(surname="Morgan"), InputAttributes(surname="Terwilliger")]),
([InputAttributes(name="Alice", surname="Morgan"), InputAttributes(surname="Terwilliger")]),
([InputAttributes(name="Eve"), InputAttributes(name="Gareth")]),
([InputAttributes(surname="Polastri"), InputAttributes(surname="Mallory")]),
([InputAttributes(name="Eve", surname="Polastri"), InputAttributes(surname="Mallory")]),
([InputAttributes()]),
([]),
],
Expand All @@ -801,20 +817,59 @@ def test__check_unique_name_fields(two_name_class_list: ClassList, input_list: I


@pytest.mark.parametrize(
"input_list",
["input_list", "error_message"],
[
([InputAttributes(name="Alice"), InputAttributes(name="Alice")]),
([InputAttributes(name="Alice"), InputAttributes(name="ALICE")]),
([InputAttributes(name="Alice"), InputAttributes(name="alice")]),
(
[InputAttributes(name="Alice"), InputAttributes(name="Bob")],
(
" 'alice' is shared between item 0 of the existing ClassList, and item 0 of the input list\n"
" 'bob' is shared between item 1 of the existing ClassList, and item 1 of the input list"
),
),
(
[InputAttributes(name="Alice"), InputAttributes(name="Alice")],
" 'alice' is shared between item 0 of the existing ClassList, and items 0 and 1 of the input list",
),
(
[InputAttributes(name="Alice"), InputAttributes(name="ALICE")],
" 'alice' is shared between item 0 of the existing ClassList, and items 0 and 1 of the input list",
),
(
[InputAttributes(name="Alice"), InputAttributes(name="alice")],
" 'alice' is shared between item 0 of the existing ClassList, and items 0 and 1 of the input list",
),
(
[InputAttributes(name="Eve"), InputAttributes(name="Eve")],
" 'eve' is shared between items 0 and 1 of the input list",
),
(
[
InputAttributes(name="Bob"),
InputAttributes(name="Alice"),
InputAttributes(name="Eve"),
InputAttributes(name="Alice"),
InputAttributes(name="Eve"),
InputAttributes(name="Alice"),
],
(
" 'bob' is shared between item 1 of the existing ClassList, and item 0 of the input list\n"
" 'alice' is shared between item 0 of the existing ClassList,"
" and items 1, 3 and 5 of the input list\n"
" 'eve' is shared between items 2 and 4 of the input list"
),
),
],
)
def test__check_unique_name_fields_not_unique(two_name_class_list: ClassList, input_list: Iterable) -> None:
def test__check_unique_name_fields_not_unique(
two_name_class_list: ClassList, input_list: Sequence, error_message: str
) -> None:
"""We should raise a ValueError if an input list contains multiple objects with (case-insensitive) matching
name_field values defined.
"""
with pytest.raises(
ValueError,
match=f"Input list contains objects with the same value of the " f"{two_name_class_list.name_field} attribute",
match=f"The value of the '{two_name_class_list.name_field}' attribute must be unique for each item in the "
f"ClassList:\n{error_message}",
):
two_name_class_list._check_unique_name_fields(input_list)

Expand All @@ -837,12 +892,15 @@ def test__check_classes(input_list: Iterable) -> None:
([InputAttributes(name="Alice"), dict(name="Bob")]),
],
)
def test__check_classes_different_classes(input_list: Iterable) -> None:
def test__check_classes_different_classes(input_list: Sequence) -> None:
"""We should raise a ValueError if an input list contains objects of different types."""
class_list = ClassList([InputAttributes()])
with pytest.raises(
ValueError,
match=(f"Input list contains elements of type other " f"than '{class_list._class_handle.__name__}'"),
match=(
f"This ClassList only supports elements of type {class_list._class_handle.__name__}. "
f"In the input list:\n index 1 is of type {type(input_list[1]).__name__}"
),
):
class_list._check_classes(input_list)

Expand Down

0 comments on commit dfef174

Please sign in to comment.