Skip to content

Commit

Permalink
Add barebones Artifact abstraction
Browse files Browse the repository at this point in the history
Adds the base interface without supplying a default implementation.
After it is checked in, we can think about whether we want to add a base
implementation, e.g. via fsspec.
  • Loading branch information
nicholasjng committed Jan 18, 2024
1 parent 355f1e2 commit 2863c98
Showing 1 changed file with 47 additions and 9 deletions.
56 changes: 47 additions & 9 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,51 @@

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable
from typing import Any, Callable, Generic, Iterable, TypeVar

T = TypeVar("T")


def NoOp(**kwargs: Any) -> None:
    """
    Do nothing.

    Accepts arbitrary keyword arguments and ignores all of them, so it can be
    called with any set of keyword arguments without raising.

    Parameters
    ----------
    **kwargs: Any
        Ignored.
    """
    return None


# TODO: Should this be frozen (since the setUp and tearDown hooks are empty returns)?
@dataclass(init=False)
class Params:
class Artifact(Generic[T]):
"""
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.
A base artifact class for loading (materializing) artifacts from disk or from remote storage.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
This is a helper to convey which kind of type gets loaded for a benchmark in a type-safe way.
It is most useful when running models on already saved data or models, e.g. when
comparing a newly trained model against a baseline in storage.
Subclasses need to implement the `Artifact.materialize()` API, telling nnbench how to
load the desired artifact from a path.
Parameters
----------
path: str | os.PathLike[str]
Path to the artifact files.
"""

pass
def __init__(self, path: str | os.PathLike[str]) -> None:
# Save the path for later just-in-time materialization.
self.path = path
self._value: T | None = None

@classmethod
def materialize(cls) -> "Artifact":
"""Load the artifact from storage."""
raise NotImplementedError

def value(self) -> T:
if self._value is None:
raise ValueError(
f"artifact has not been instantiated yet, "
f"perhaps you forgot to call {self.__class__.__name__}.materialize()?"
)
return self._value


@dataclass(frozen=True)
Expand Down Expand Up @@ -63,6 +88,19 @@ def __post_init__(self):
# TODO: Parse interface using `inspect`, attach to the class


@dataclass(frozen=True, init=False)
class Params:
    """
    A dataclass designed to hold benchmark parameters. This class is not functional
    on its own, and needs to be subclassed according to your benchmarking workloads.

    The main advantage over passing parameters as a dictionary is, of course,
    static analysis and type safety for your benchmarking code.

    The class declares no fields of its own. ``frozen=True`` makes assigning to
    fields of an instance raise, and ``init=False`` means the dataclass
    decorator does not generate an ``__init__`` for this base class.
    """


def benchmark(
func: Callable[..., Any] | None = None,
params: dict[str, Any] | None = None,
Expand Down

0 comments on commit 2863c98

Please sign in to comment.