Skip to content

Commit

Permalink
Address PR Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
maxmynter committed Jan 24, 2024
1 parent 1a01ce5 commit d480a23
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
9 changes: 5 additions & 4 deletions src/nnbench/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Generic, TypedDict, TypeVar

T = TypeVar("T")
Variable = tuple[str, type]


class BenchmarkResult(TypedDict):
Expand Down Expand Up @@ -104,9 +105,9 @@ def __post_init__(self):

@dataclass(frozen=True)
class Interface:
varnames: tuple[str, ...]
vartypes: tuple[type, ...]
varitems: tuple[tuple[str, inspect.Parameter], ...]
names: tuple[str, ...]
types: tuple[type, ...]
variables: tuple[Variable, ...]
defaults: dict[str, Any]

@classmethod
Expand All @@ -115,6 +116,6 @@ def from_callable(cls, fn: Callable) -> Interface:
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple(sig.parameters.items()),
tuple((k, v.annotation) for k, v in sig.parameters.items()),
{n: p.default for n, p in sig.parameters.items()},
)
13 changes: 6 additions & 7 deletions tests/test_benchmark_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ def empty_function() -> None:
pass

interface = types.Interface.from_callable(empty_function)
assert interface.varnames == ()
assert interface.vartypes == ()
assert interface.varitems == ()
assert interface.names == ()
assert interface.types == ()
assert interface.variables == ()
assert interface.defaults == {}


Expand All @@ -19,14 +19,13 @@ def complex_function(a: int, b, c: str = "hello", d: float = 10.0) -> None: # t
pass

interface = types.Interface.from_callable(complex_function)
assert interface.varnames == ("a", "b", "c", "d")
assert interface.vartypes == (
assert interface.names == ("a", "b", "c", "d")
assert interface.types == (
int,
inspect._empty,
str,
float,
)

varitems = [(param.name, param.annotation) for name, param in interface.varitems]
assert varitems == [("a", int), ("b", inspect._empty), ("c", str), ("d", float)]
assert interface.variables == (("a", int), ("b", inspect._empty), ("c", str), ("d", float))
assert interface.defaults == {"a": inspect._empty, "b": inspect._empty, "c": "hello", "d": 10.0}

0 comments on commit d480a23

Please sign in to comment.