diff --git a/src/nnbench/context.py b/src/nnbench/context.py index c02fa91..80ceedb 100644 --- a/src/nnbench/context.py +++ b/src/nnbench/context.py @@ -1,9 +1,8 @@ """Utilities for collecting context key-value pairs as metadata in benchmark runs.""" -import itertools import platform import sys -from collections.abc import Callable, Iterator +from collections.abc import Callable from typing import Any, Literal ContextProvider = Callable[[], dict[str, Any]] @@ -194,219 +193,3 @@ def __call__(self) -> dict[str, Any]: result["memory_unit"] = self.memunit # TODO: Lacks CPU cache info, which requires a solution other than psutil. return {self.key: result} - - -class Context: - def __init__(self, data: dict[str, Any] | None = None) -> None: - self._data: dict[str, Any] = data or {} - - def __contains__(self, key: str) -> bool: - return key in self.keys() - - def __eq__(self, other): - if not isinstance(other, Context): - raise NotImplementedError( - f"cannot compare {type(self)} for equality with type {type(other)}" - ) - return self._data.__eq__(other._data) - - @property - def data(self): - return self._data - - @staticmethod - def _ctx_items(d: dict[str, Any], prefix: str, sep: str) -> Iterator[tuple[str, Any]]: - """ - Iterate over nested dictionary items. Keys are formatted to indicate their nested path. - - Parameters - ---------- - d : dict[str, Any] - Dictionary to iterate over. - prefix : str - Current prefix to prepend to keys, used for recursion to build the full key path. - sep : str - The separator to use between levels of nesting in the key path. - - Yields - ------ - tuple[str, Any] - Iterator over key-value tuples. - """ - for k, v in d.items(): - new_key = prefix + sep + k if prefix else k - if isinstance(v, dict): - yield from Context._ctx_items(d=v, prefix=new_key, sep=sep) - else: - yield new_key, v - - def keys(self, sep: str = ".") -> Iterator[str]: - """ - Keys of the context dictionary, with an optional separator for nested keys. - - Parameters - ---------- - sep : str, optional - Separator to use for nested keys. - - Yields - ------ - str - Iterator over the context dictionary keys. - """ - for k, _ in self._ctx_items(d=self._data, prefix="", sep=sep): - yield k - - def values(self) -> Iterator[Any]: - """ - Values of the context dictionary, including values from nested dictionaries. - - Yields - ------ - Any - Iterator over all values in the context dictionary. - """ - for _, v in self._ctx_items(d=self._data, prefix="", sep=""): - yield v - - def items(self, sep: str = ".") -> Iterator[tuple[str, Any]]: - """ - Items (key-value pairs) of the context dictionary, with an separator for nested keys. - - Parameters - ---------- - sep : str, optional - Separator to use for nested dictionary keys. - - Yields - ------ - tuple[str, Any] - Iterator over the items of the context dictionary. - """ - yield from self._ctx_items(d=self._data, prefix="", sep=sep) - - def add(self, provider: ContextProvider, replace: bool = False) -> None: - """ - Adds data from a provider to the context. - - Parameters - ---------- - provider : ContextProvider - The provider to inject into this context. - replace : bool - Whether to replace existing context values upon key collision. Raises ValueError otherwise. - """ - self.update(Context.make(provider()), replace=replace) - - def update(self, other: "Context", replace: bool = False) -> None: - """ - Updates the context. - - Parameters - ---------- - other : Context - The other context to update this context with. 
- replace : bool - Whether to replace existing context values upon key collision. Raises ValueError otherwise. - - Raises - ------ - ValueError - If ``other contains top-level keys already present in the context and ``replace=False``. - """ - duplicates = set(self.keys()) & set(other.keys()) - if not replace and duplicates: - dupe, *_ = duplicates - raise ValueError(f"got multiple values for context key {dupe!r}") - self._data.update(other._data) - - @staticmethod - def _flatten_dict(d: dict[str, Any], prefix: str = "", sep: str = ".") -> dict[str, Any]: - """ - Turn a nested dictionary into a flattened dictionary. - - Parameters - ---------- - d : dict[str, Any] - (Possibly) nested dictionary to flatten. - prefix : str - Key prefix to apply at the top-level (nesting level 0). - sep : str - Separator on which to join keys, "." by default. - - Returns - ------- - dict[str, Any] - The flattened dictionary. - """ - - items: list[tuple[str, Any]] = [] - for key, value in d.items(): - new_key = prefix + sep + key if prefix else key - if isinstance(value, dict): - items.extend(Context._flatten_dict(d=value, prefix=new_key, sep=sep).items()) - else: - items.append((new_key, value)) - return dict(items) - - def flatten(self, sep: str = ".") -> dict[str, Any]: - """ - Flatten the context's dictionary, converting nested dictionaries into a single dictionary with keys separated by `sep`. - - Parameters - ---------- - sep : str, optional - The separator used to join nested keys. - - Returns - ------- - dict[str, Any] - The flattened context values as a Python dictionary. - """ - - return self._flatten_dict(self._data, prefix="", sep=sep) - - @staticmethod - def unflatten(d: dict[str, Any], sep: str = ".") -> dict[str, Any]: - """ - Recursively unflatten a dictionary by expanding keys seperated by `sep` into nested dictionaries. - - Parameters - ---------- - d : dict[str, Any] - The dictionary to unflatten. - sep : str, optional - The separator used in the flattened keys. - - Returns - ------- - dict[str, Any] - The unflattened dictionary. - """ - sorted_keys = sorted(d.keys()) - unflattened = {} - for prefix, keys in itertools.groupby(sorted_keys, key=lambda key: key.split(sep, 1)[0]): - key_group = list(keys) - if len(key_group) == 1 and sep not in key_group[0]: - unflattened[prefix] = d[prefix] - else: - nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group} - unflattened[prefix] = Context.unflatten(d=nested_dict, sep=sep) - return unflattened - - @classmethod - def make(cls, d: dict[str, Any]) -> "Context": - """ - Create a new Context instance from a given dictionary. - - Parameters - ---------- - d : dict[str, Any] - The initialization dictionary. - - Returns - ------- - Context - The new Context instance. 
- """ - return cls(data=cls.unflatten(d)) diff --git a/src/nnbench/reporter/duckdb_sql.py b/src/nnbench/reporter/duckdb_sql.py index 12a169a..ce9a565 100644 --- a/src/nnbench/reporter/duckdb_sql.py +++ b/src/nnbench/reporter/duckdb_sql.py @@ -5,8 +5,6 @@ import weakref from pathlib import Path -from nnbench.context import Context - try: import duckdb @@ -122,7 +120,7 @@ def read_sql( results = rel.fetchall() benchmarks = [dict(zip(columns, r)) for r in results] - context = Context() + context = {} for bm in benchmarks: context.update(bm.pop("context", {})) diff --git a/src/nnbench/runner.py b/src/nnbench/runner.py index f23993e..db05e6b 100644 --- a/src/nnbench/runner.py +++ b/src/nnbench/runner.py @@ -14,7 +14,7 @@ from pathlib import Path from typing import Any, get_origin -from nnbench.context import Context, ContextProvider +from nnbench.context import ContextProvider from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State from nnbench.types.memo import is_memo, is_memo_type from nnbench.util import import_file_as_module, ismodule @@ -268,7 +268,7 @@ def run( path_or_module: str | os.PathLike[str], params: dict[str, Any] | Parameters | None = None, tags: tuple[str, ...] = (), - context: Sequence[ContextProvider] | Context = (), + context: Sequence[ContextProvider] = (), ) -> BenchmarkRecord: """ Run a previously collected benchmark workload. @@ -284,7 +284,7 @@ def run( tags: tuple[str, ...] Tags to filter for when collecting benchmarks. Only benchmarks containing either of these tags are collected. - context: Sequence[ContextProvider] | Context + context: Sequence[ContextProvider] Additional context to log with the benchmark in the output JSON record. Useful for obtaining environment information and configuration, like CPU/GPU hardware info, ML model metadata, and more. @@ -302,12 +302,14 @@ def run( family_sizes: dict[str, Any] = collections.defaultdict(int) family_indices: dict[str, Any] = collections.defaultdict(int) - if isinstance(context, Context): - ctx = context - else: - ctx = Context() - for provider in context: - ctx.add(provider) + ctx: dict[str, Any] = {} + for provider in context: + val = provider() + duplicates = set(ctx.keys()) & set(val.keys()) + if duplicates: + dupe, *_ = duplicates + raise ValueError(f"got multiple values for context key {dupe!r}") + ctx.update(val) # if we didn't find any benchmarks, warn and return an empty record. if not self.benchmarks: diff --git a/src/nnbench/types/benchmark.py b/src/nnbench/types/benchmark.py index 566b15f..83c7216 100644 --- a/src/nnbench/types/benchmark.py +++ b/src/nnbench/types/benchmark.py @@ -12,7 +12,6 @@ else: from typing_extensions import Self -from nnbench.context import Context from nnbench.types.interface import Interface @@ -30,7 +29,7 @@ def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None @dataclass(frozen=True) class BenchmarkRecord: - context: Context + context: dict[str, Any] benchmarks: list[dict[str, Any]] def compact( @@ -63,12 +62,7 @@ def compact( for b in self.benchmarks: bc = copy.deepcopy(b) - if mode == "inline": - bc["context"] = self.context.data - elif mode == "flatten": - flat = self.context.flatten(sep=sep) - bc.update(flat) - bc["_contextkeys"] = list(self.context.keys()) + bc["context"] = self.context result.append(bc) return result @@ -90,16 +84,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> Self: The resulting record with the context extracted. 
""" - dctx: dict[str, Any] = {} + ctx: dict[str, Any] = {} for b in bms: if "context" in b: - dctx = b.pop("context") + ctx = b.pop("context") elif "_contextkeys" in b: ctxkeys = b.pop("_contextkeys") for k in ctxkeys: # This should never throw, save for data corruption. - dctx[k] = b.pop(k) - return cls(context=Context.make(dctx), benchmarks=bms) + ctx[k] = b.pop(k) + return cls(context=ctx, benchmarks=bms) # TODO: Add an expandmany() API for returning a sequence of records for heterogeneous # context data. @@ -151,5 +145,3 @@ class Parameters: The main advantage over passing parameters as a dictionary is, of course, static analysis and type safety for your benchmarking code. """ - - pass diff --git a/src/nnbench/util.py b/src/nnbench/util.py index d4a2a8d..df4c087 100644 --- a/src/nnbench/util.py +++ b/src/nnbench/util.py @@ -2,6 +2,7 @@ import importlib import importlib.util +import itertools import os import sys from importlib.machinery import ModuleSpec @@ -9,6 +10,30 @@ from types import ModuleType +def flatten(d: dict, sep: str = ".", prefix: str = "") -> dict: + d_flat = {} + for k, v in d.items(): + new_key = prefix + sep + k if prefix else k + if isinstance(v, dict): + d_flat.update(flatten(v, sep=sep, prefix=new_key)) + else: + d_flat[k] = v + return d_flat + + +def unflatten(d: dict, sep: str = ".") -> dict: + sorted_keys = sorted(d.keys()) + unflattened = {} + for prefix, keys in itertools.groupby(sorted_keys, key=lambda key: key.split(sep, 1)[0]): + key_group = list(keys) + if len(key_group) == 1 and sep not in key_group[0]: + unflattened[prefix] = d[prefix] + else: + nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group} + unflattened[prefix] = unflatten(nested_dict, sep=sep) + return unflattened + + def ismodule(name: str | os.PathLike[str]) -> bool: """Checks if the current interpreter has an available Python module named `name`.""" name = str(name) diff --git a/tests/test_context.py b/tests/test_context.py deleted file mode 100644 index 5449497..0000000 --- a/tests/test_context.py +++ /dev/null @@ -1,108 +0,0 @@ -import pytest - -from nnbench.context import Context, CPUInfo, GitEnvironmentInfo, PythonInfo - - -def test_python_package_info() -> None: - p = PythonInfo("pre-commit", "pyyaml")() - res = p["python"] - - deps = res["dependencies"] - for v in deps.values(): - assert v != "" - - # for a bogus package, it should not fail but produce an empty string. - p = PythonInfo("asdfghjkl")() - res = p["python"] - - deps = res["dependencies"] - for v in deps.values(): - assert v == "" - - -def test_git_info_provider() -> None: - g = GitEnvironmentInfo()() - res = g["git"] - # tag might not be available in a shallow checkout in CI, - # but commit, provider and repo are. - assert res["commit"] != "" - assert res["provider"] == "github.com" - assert res["repository"] == "aai-institute/nnbench" - - -def test_cpu_info_provider() -> None: - c = CPUInfo()() - res = c["cpu"] - - # just check that the most important fields are populated across - # the popular CPU architectures. 
- assert res["architecture"] != "" - assert res["system"] != "" - assert res["frequency"] >= 0 - assert res["num_cpus"] > 0 - assert res["total_memory"] > 0 - - -def test_flatten_nested_dictionary(): - nested_ctx = Context({"a": 1, "b": {"c": 2, "d": {"e": 3}}}) - flattened = nested_ctx.flatten() - assert flattened == {"a": 1, "b.c": 2, "b.d.e": 3} - - -def test_unflatten_dictionary(): - flat_data = {"a": 1, "b.c": 2, "b.d.e": 3} - unflattened = Context.unflatten(flat_data) - assert unflattened == {"a": 1, "b": {"c": 2, "d": {"e": 3}}} - - -def test_context_keys(): - ctx = Context({"a": 1, "b": {"c": 2}}) - expected_keys = {"a", "b.c"} - assert set(ctx.keys()) == expected_keys - - -def test_context_values(): - ctx = Context({"a": 1, "b": {"c": 2}}) - expected_values = {1, 2} - assert set(ctx.values()) == expected_values - - -def test_context_items(): - ctx = Context({"a": 1, "b": {"c": 2}}) - expected_items = {("a", 1), ("b.c", 2)} - assert set(ctx.items()) == expected_items - - -def test_context_update_with_key_collision(): - ctx = Context({"a": 1, "b": 2}) - with pytest.raises(ValueError, match=r".*multiple values for context.*"): - ctx.update(Context.make({"a": 3, "c": 4})) - - -def test_context_update_duplicate_with_replace(): - ctx = Context({"a": 1, "b": 2}) - ctx.update(Context.make({"a": 3, "c": 4}), replace=True) - expected_dict = {"a": 3, "b": 2, "c": 4} - assert ctx._data == expected_dict - - -def test_update_with_context(): - ctx = Context({"a": 1, "b": 2}) - ctx.update(Context.make({"c": 4})) - expected_dict = {"a": 1, "b": 2, "c": 4} - assert ctx._data == expected_dict - - -def test_add_with_provider(): - ctx = Context({"a": 1, "b": 2}) - ctx.add(lambda: {"c": 4}) - expected_dict = {"a": 1, "b": 2, "c": 4} - assert ctx._data == expected_dict - - -def test_update_with_context_instance(): - ctx1 = Context({"a": 1, "b": {"c": 2}}) - ctx2 = Context({"d": 4}) - ctx1.update(ctx2) - expected_dict = {"a": 1, "b": {"c": 2}, "d": 4} - assert ctx1._data == expected_dict diff --git a/tests/test_fileio.py b/tests/test_fileio.py index a03d4f2..16988d5 100644 --- a/tests/test_fileio.py +++ b/tests/test_fileio.py @@ -4,7 +4,6 @@ import pytest -from nnbench.context import Context from nnbench.reporter.file import FileIO from nnbench.types import BenchmarkRecord @@ -20,7 +19,7 @@ def test_fileio_writes_no_compression_inline( f = FileIO() rec = BenchmarkRecord( - context=Context.make({"a": "b", "s": 1, "b.c": 1.0}), + context={"a": "b", "s": 1, "b.c": 1.0}, benchmarks=[{"name": "foo", "value": 1}, {"name": "bar", "value": 2}], ) file = tmp_path / f"record.{ext}"
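Note: as a sanity check on the refactor above, here is a minimal sketch of the plain-dict context round-trip, assuming the `flatten`/`unflatten` helpers land in `src/nnbench/util.py` as in this patch. The context values are made up for illustration:

    # Sketch: context is now a plain nested dict; nnbench.util provides the
    # flatten/unflatten helpers that replace Context.flatten()/Context.unflatten().
    from typing import Any

    from nnbench.util import flatten, unflatten

    # Hypothetical context data, similar in shape to what the providers collect.
    ctx: dict[str, Any] = {
        "git": {"commit": "abc123", "dirty": False},
        "python": {"version": "3.11"},
    }

    # Nested keys are joined on the separator, "." by default.
    flat = flatten(ctx)
    assert flat == {"git.commit": "abc123", "git.dirty": False, "python.version": "3.11"}

    # unflatten is the inverse and rebuilds the nesting.
    assert unflatten(flat) == ctx

Since `BenchmarkRecord.context` is now a plain dict, reporters can merge it with `dict.update()`, as `read_sql` in `duckdb_sql.py` does above.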