diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py deleted file mode 100644 index 4f48aeb..0000000 --- a/tests/test_artifacts.py +++ /dev/null @@ -1,11 +0,0 @@ -from pathlib import Path - -from nnbench.types import FilePathArtifactLoader - - -def test_load_local_file(local_file: Path, tmp_path: Path) -> None: - test_dir = tmp_path / "test_load_dir" - loader = FilePathArtifactLoader(local_file, test_dir) - loaded_path: Path = loader.load() - assert loaded_path.exists() - assert loaded_path.read_text() == "Test content" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 743257f..a3fa1d1 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -30,9 +30,9 @@ def parametrized_benchmark(param: int) -> int: assert len(parametrized_benchmark) == 2 assert has_expected_args(parametrized_benchmark[0].fn, {"param": 1}) - assert parametrized_benchmark[0].fn() == 1 + assert parametrized_benchmark[0].fn(**parametrized_benchmark[0].params) == 1 assert has_expected_args(parametrized_benchmark[1].fn, {"param": 2}) - assert parametrized_benchmark[1].fn() == 2 + assert parametrized_benchmark[1].fn(**parametrized_benchmark[1].params) == 2 def test_parametrize_with_duplicate_parameters(): @@ -50,13 +50,13 @@ def product_benchmark(iter1: int, iter2: str) -> tuple[int, str]: assert len(product_benchmark) == 4 assert has_expected_args(product_benchmark[0].fn, {"iter1": 1, "iter2": "a"}) - assert product_benchmark[0].fn() == (1, "a") + assert product_benchmark[0].fn(**product_benchmark[0].params) == (1, "a") assert has_expected_args(product_benchmark[1].fn, {"iter1": 1, "iter2": "b"}) - assert product_benchmark[1].fn() == (1, "b") + assert product_benchmark[1].fn(**product_benchmark[1].params) == (1, "b") assert has_expected_args(product_benchmark[2].fn, {"iter1": 2, "iter2": "a"}) - assert product_benchmark[2].fn() == (2, "a") + assert product_benchmark[2].fn(**product_benchmark[2].params) == (2, "a") assert has_expected_args(product_benchmark[3].fn, {"iter1": 2, "iter2": "b"}) - assert product_benchmark[3].fn() == (2, "b") + assert product_benchmark[3].fn(**product_benchmark[3].params) == (2, "b") def test_product_with_duplicate_parameters(): diff --git a/tests/test_types.py b/tests/test_types.py index 83ccd8f..ea1bad1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,13 +1,13 @@ import inspect -from nnbench import types +from nnbench.types.types import Interface def test_interface_with_no_arguments(): def fn() -> None: pass - interface = types.Interface.from_callable(fn) + interface = Interface.from_callable(fn, {}) assert interface.names == () assert interface.types == () assert interface.defaults == () @@ -19,7 +19,7 @@ def test_interface_with_multiple_arguments(): def fn(a: int, b, c: str = "hello", d: float = 10.0) -> None: # type: ignore pass - interface = types.Interface.from_callable(fn) + interface = Interface.from_callable(fn, {}) empty = inspect.Parameter.empty assert interface.names == ("a", "b", "c", "d") assert interface.types == (int, empty, str, float)