From 2863c988f41f2b045c995f81068bfad4cb8770ee Mon Sep 17 00:00:00 2001
From: Nicholas Junge
Date: Thu, 18 Jan 2024 12:16:23 +0100
Subject: [PATCH] Add barebones Artifact abstraction

Adds the base interface without supplying a default implementation.
After it is checked in, we can think about whether we want to add a base
implementation, e.g. via fsspec.
---
Review notes (free-text area below the three-dash line; not part of the
commit message):

 * The line structure of this patch was restored from a whitespace-collapsed
   copy. Blank lines were placed so that both hunks match the line counts
   in their headers; the wrapping of the commit message above and the
   indentation of the first context line of the second hunk are assumed
   and should be confirmed against the original commit.
 * `Artifact.materialize()` is declared as a classmethod taking only `cls`,
   so it cannot read the `path` stored by `__init__`, and nothing in this
   patch assigns `self._value`. As written, `value()` raises ValueError
   unless a subclass sets `_value` itself.
 * The `value()` error message says "instantiated" while the suggested
   remedy, and the rest of the docstring, say "materialize".
 * `Params` is removed from its old position, re-added further down, and
   changes from `@dataclass(init=False)` to
   `@dataclass(frozen=True, init=False)`, which settles the deleted TODO.
   NOTE(review): dataclass subclasses of a frozen dataclass presumably
   have to be declared frozen as well -- confirm this is intended for
   user-defined `Params` subclasses.

 src/nnbench/core.py | 56 +++++++++++++++++++++++++++++++++++++--------
 1 file changed, 47 insertions(+), 9 deletions(-)

diff --git a/src/nnbench/core.py b/src/nnbench/core.py
index 21dab86..aa51838 100644
--- a/src/nnbench/core.py
+++ b/src/nnbench/core.py
@@ -2,26 +2,51 @@
 
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass, field
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Generic, Iterable, TypeVar
+
+T = TypeVar("T")
 
 
 def NoOp(**kwargs: Any) -> None:
     pass
 
 
-# TODO: Should this be frozen (since the setUp and tearDown hooks are empty returns)?
-@dataclass(init=False)
-class Params:
+class Artifact(Generic[T]):
     """
-    A dataclass designed to hold benchmark parameters. This class is not functional
-    on its own, and needs to be subclassed according to your benchmarking workloads.
+    A base artifact class for loading (materializing) artifacts from disk or from remote storage.
 
-    The main advantage over passing parameters as a dictionary is, of course,
-    static analysis and type safety for your benchmarking code.
+    This is a helper to convey which kind of type gets loaded for a benchmark in a type-safe way.
+    It is most useful when running models on already saved data or models, e.g. when
+    comparing a newly trained model against a baseline in storage.
+
+    Subclasses need to implement the `Artifact.materialize()` API, telling nnbench how to
+    load the desired artifact from a path.
+
+    Parameters
+    ----------
+    path: str | os.PathLike[str]
+        Path to the artifact files.
     """
 
-    pass
+    def __init__(self, path: str | os.PathLike[str]) -> None:
+        # Save the path for later just-in-time materialization.
+        self.path = path
+        self._value: T | None = None
+
+    @classmethod
+    def materialize(cls) -> "Artifact":
+        """Load the artifact from storage."""
+        raise NotImplementedError
+
+    def value(self) -> T:
+        if self._value is None:
+            raise ValueError(
+                f"artifact has not been instantiated yet, "
+                f"perhaps you forgot to call {self.__class__.__name__}.materialize()?"
+            )
+        return self._value
 
 
 @dataclass(frozen=True)
@@ -63,6 +88,19 @@ def __post_init__(self):
         # TODO: Parse interface using `inspect`, attach to the class
 
 
+@dataclass(frozen=True, init=False)
+class Params:
+    """
+    A dataclass designed to hold benchmark parameters. This class is not functional
+    on its own, and needs to be subclassed according to your benchmarking workloads.
+
+    The main advantage over passing parameters as a dictionary is, of course,
+    static analysis and type safety for your benchmarking code.
+    """
+
+    pass
+
+
 def benchmark(
     func: Callable[..., Any] | None = None,
     params: dict[str, Any] | None = None,