Skip to content

Commit

Permalink
Migrate tests to new params assumptions
Browse files Browse the repository at this point in the history
Partial parametrizations are not bound eagerly to the benchmark functions
anymore, which makes it simpler to inject memos and de-memoize variables
just in time for execution.

What is left is validation that a subsequent benchmark of models with
intermittent garbage collection actually reaps each model after the
benchmark is done.
  • Loading branch information
nicholasjng committed Mar 21, 2024
1 parent cb8fa26 commit 98c2909
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 20 deletions.
11 changes: 0 additions & 11 deletions tests/test_artifacts.py

This file was deleted.

12 changes: 6 additions & 6 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def parametrized_benchmark(param: int) -> int:

assert len(parametrized_benchmark) == 2
assert has_expected_args(parametrized_benchmark[0].fn, {"param": 1})
assert parametrized_benchmark[0].fn() == 1
assert parametrized_benchmark[0].fn(**parametrized_benchmark[0].params) == 1
assert has_expected_args(parametrized_benchmark[1].fn, {"param": 2})
assert parametrized_benchmark[1].fn() == 2
assert parametrized_benchmark[1].fn(**parametrized_benchmark[1].params) == 2


def test_parametrize_with_duplicate_parameters():
Expand All @@ -50,13 +50,13 @@ def product_benchmark(iter1: int, iter2: str) -> tuple[int, str]:

assert len(product_benchmark) == 4
assert has_expected_args(product_benchmark[0].fn, {"iter1": 1, "iter2": "a"})
assert product_benchmark[0].fn() == (1, "a")
assert product_benchmark[0].fn(**product_benchmark[0].params) == (1, "a")
assert has_expected_args(product_benchmark[1].fn, {"iter1": 1, "iter2": "b"})
assert product_benchmark[1].fn() == (1, "b")
assert product_benchmark[1].fn(**product_benchmark[1].params) == (1, "b")
assert has_expected_args(product_benchmark[2].fn, {"iter1": 2, "iter2": "a"})
assert product_benchmark[2].fn() == (2, "a")
assert product_benchmark[2].fn(**product_benchmark[2].params) == (2, "a")
assert has_expected_args(product_benchmark[3].fn, {"iter1": 2, "iter2": "b"})
assert product_benchmark[3].fn() == (2, "b")
assert product_benchmark[3].fn(**product_benchmark[3].params) == (2, "b")


def test_product_with_duplicate_parameters():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import inspect

from nnbench import types
from nnbench.types.types import Interface


def test_interface_with_no_arguments():
def fn() -> None:
pass

interface = types.Interface.from_callable(fn)
interface = Interface.from_callable(fn, {})
assert interface.names == ()
assert interface.types == ()
assert interface.defaults == ()
Expand All @@ -19,7 +19,7 @@ def test_interface_with_multiple_arguments():
def fn(a: int, b, c: str = "hello", d: float = 10.0) -> None: # type: ignore
pass

interface = types.Interface.from_callable(fn)
interface = Interface.from_callable(fn, {})
empty = inspect.Parameter.empty
assert interface.names == ("a", "b", "c", "d")
assert interface.types == (int, empty, str, float)
Expand Down

0 comments on commit 98c2909

Please sign in to comment.