diff --git a/src/nnbench/types.py b/src/nnbench/types.py index f497378..52d51df 100644 --- a/src/nnbench/types.py +++ b/src/nnbench/types.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Generic, TypedDict, TypeVar T = TypeVar("T") +Variable = tuple[str, type] class BenchmarkResult(TypedDict): @@ -104,9 +105,9 @@ def __post_init__(self): @dataclass(frozen=True) class Interface: - varnames: tuple[str, ...] - vartypes: tuple[type, ...] - varitems: tuple[tuple[str, inspect.Parameter], ...] + names: tuple[str, ...] + types: tuple[type, ...] + variables: tuple[Variable, ...] defaults: dict[str, Any] @classmethod @@ -115,6 +116,6 @@ def from_callable(cls, fn: Callable) -> Interface: return cls( tuple(sig.parameters.keys()), tuple(p.annotation for p in sig.parameters.values()), - tuple(sig.parameters.items()), + tuple((k, v.annotation) for k, v in sig.parameters.items()), {n: p.default for n, p in sig.parameters.items()}, ) diff --git a/tests/test_benchmark_cls.py b/tests/test_benchmark_cls.py index b697c7f..07c48a6 100644 --- a/tests/test_benchmark_cls.py +++ b/tests/test_benchmark_cls.py @@ -8,9 +8,9 @@ def empty_function() -> None: pass interface = types.Interface.from_callable(empty_function) - assert interface.varnames == () - assert interface.vartypes == () - assert interface.varitems == () + assert interface.names == () + assert interface.types == () + assert interface.variables == () assert interface.defaults == {} @@ -19,14 +19,13 @@ def complex_function(a: int, b, c: str = "hello", d: float = 10.0) -> None: # t pass interface = types.Interface.from_callable(complex_function) - assert interface.varnames == ("a", "b", "c", "d") - assert interface.vartypes == ( + assert interface.names == ("a", "b", "c", "d") + assert interface.types == ( int, inspect._empty, str, float, ) - varitems = [(param.name, param.annotation) for name, param in interface.varitems] - assert varitems == [("a", int), ("b", inspect._empty), ("c", str), ("d", float)] + assert interface.variables == (("a", int), ("b", inspect._empty), ("c", str), ("d", float)) assert interface.defaults == {"a": inspect._empty, "b": inspect._empty, "c": "hello", "d": 10.0}