diff --git a/src/nnbench/core.py b/src/nnbench/core.py index 6c0cc28..b447e3c 100644 --- a/src/nnbench/core.py +++ b/src/nnbench/core.py @@ -9,7 +9,7 @@ import warnings from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload -from nnbench.types import Benchmark +from nnbench.types import Benchmark, State from nnbench.types.util import is_memo, is_memo_type @@ -167,7 +167,7 @@ def parametrize( def decorator(fn: Callable) -> list[Benchmark]: benchmarks = [] names = set() - for params in parameters: + for idx, params in enumerate(parameters): _check_against_interface(params, fn) name = namegen(fn, **params) @@ -177,8 +177,23 @@ def decorator(fn: Callable) -> list[Benchmark]: f"Perhaps you specified a parameter configuration twice?" ) names.add(name) + state = State( + name=name, + function=fn, + family=fn.__name__, + family_size=len(list(parameters)), + family_index=idx, + ) - bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags) + bm = Benchmark( + fn, + name=name, + params=params, + setUp=setUp, + tearDown=tearDown, + tags=tags, + state=state, + ) benchmarks.append(bm) return benchmarks @@ -224,7 +239,8 @@ def decorator(fn: Callable) -> list[Benchmark]: benchmarks = [] names = set() varnames = iterables.keys() - for values in itertools.product(*iterables.values()): + cartesian_product = list(itertools.product(*iterables.values())) + for idx, values in enumerate(cartesian_product): params = dict(zip(varnames, values)) _check_against_interface(params, fn) @@ -235,8 +251,23 @@ def decorator(fn: Callable) -> list[Benchmark]: f"Perhaps you specified a parameter configuration twice?" ) names.add(name) + state = State( + name=name, + function=fn, + family=fn.__name__, + family_size=len(cartesian_product), + family_index=idx, + ) - bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags) + bm = Benchmark( + fn, + name=name, + params=params, + setUp=setUp, + tearDown=tearDown, + tags=tags, + state=state, + ) benchmarks.append(bm) return benchmarks diff --git a/src/nnbench/types/__init__.py b/src/nnbench/types/__init__.py index c7be11a..90c64a3 100644 --- a/src/nnbench/types/__init__.py +++ b/src/nnbench/types/__init__.py @@ -1 +1 @@ -from .types import Benchmark, BenchmarkRecord, Memo, Parameters +from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py index ba58492..8dd71bb 100644 --- a/src/nnbench/types/types.py +++ b/src/nnbench/types/types.py @@ -14,6 +14,7 @@ TypeVar, ) +from nnbench import __version__ from nnbench.context import Context T = TypeVar("T") @@ -101,6 +102,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord: # context data. +@dataclass(frozen=True) +class State: + name: str + function: Callable[..., Any] + family: str + family_size: int + family_index: int + nnbench_version: str = __version__ + + class Memo(Generic[T]): @functools.cache # TODO: Swap this out for a local type-wide memo cache. @@ -164,11 +175,41 @@ class Benchmark: tearDown: Callable[..., None] = field(repr=False, default=NoOp) tags: tuple[str, ...] = field(repr=False, default=()) interface: Interface = field(init=False, repr=False) + state: State | None = field(default=None) def __post_init__(self): if not self.name: super().__setattr__("name", self.fn.__name__) super().__setattr__("interface", Interface.from_callable(self.fn, self.params)) + if not self.state: + super().__setattr__( + "state", + State( + name=self.name or "", + function=self.fn, + family=self.fn.__name__, + family_size=1, + family_index=0, + ), + ) + + original_setUp = self.setUp + + def wrapped_setUp(*args: Any, **kwargs: Any) -> None: + state = self.state + # TODO: setUp logic + original_setUp(*args, **kwargs) + + super().__setattr__("setUp", wrapped_setUp) + + original_tearDown = self.tearDown + + def wrapped_tearDown(*args: Any, **kwargs: Any) -> None: + state = self.state + # TODO: tearDown logic + original_tearDown(*args, **kwargs) + + super().__setattr__("tearDown", wrapped_tearDown) @dataclass(frozen=True)