From 09bbfde60934b427080f850b726ced23bf7b267a Mon Sep 17 00:00:00 2001 From: Max Mynter Date: Mon, 25 Mar 2024 18:40:17 +0100 Subject: [PATCH 1/4] (wip) Add caching on type-level with the help of two classes. The MemoCache is not intended to be instantiated but holds a global cache. One can interact with this cache via classmethods. The `Memo` class is to be used as a decorator on whatever one wants to cache. It computes a value lazily and caches it within `MemoCache`. You can delete the class to empty the cache. --- src/nnbench/types/types.py | 75 +++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py index ba58492..ae2a4e0 100644 --- a/src/nnbench/types/types.py +++ b/src/nnbench/types/types.py @@ -3,16 +3,10 @@ from __future__ import annotations import copy -import functools import inspect +import threading from dataclasses import dataclass, field -from typing import ( - Any, - Callable, - Generic, - Literal, - TypeVar, -) +from typing import Any, Callable, Generic, Literal, TypeVar from nnbench.context import Context @@ -101,23 +95,60 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord: # context data. +class MemoCache: + _cache: dict[int, Any] = {} + _lock = threading.Lock() + + def __new__(cls) -> "MemoCache": + raise NotImplementedError( + "MemoCache is a module-level singleton utility. It cannot be instantiated." + ) + + @classmethod + def set(cls, key: int, value: Any) -> None: + with cls._lock: + cls._cache[key] = value + + @classmethod + def clear(cls) -> None: + with cls._lock: + cls._cache.clear() + + @classmethod + def drop(cls, key: int) -> None: + with cls._lock: + del cls._cache[key] + + @classmethod + def get(cls, key: int) -> Any: + with cls._lock: + if key in cls._cache: + return cls._cache[key] + else: + return None + + class Memo(Generic[T]): - @functools.cache - # TODO: Swap this out for a local type-wide memo cache. 
- # Could also be a decorator, should look a bit like this: - # global memo_cache, memo_cache_lock - # _tid = id(self) - # val: T - # with memocache_lock: - # if _tid in memo_cache: - # val = memo_cache[_tid] - # return val - # val = self.compute() - # memo_cache[_tid] = val - # return val - def __call__(self) -> T: + def __init__(self, content: Callable[..., T]) -> None: + self._tid = id(self) + self.content = content + + def compute(self, *args: Any, **kwargs: Any) -> T: raise NotImplementedError + def __call__(self, *args: Any, **kwargs: Any) -> T: + val = MemoCache.get(self._tid) + if val is None: + val = self.compute(*args, **kwargs) + MemoCache.set(self._tid, val) + return val + + def __del__(self) -> None: + try: + MemoCache.drop((self._tid)) + except KeyError: + pass + @dataclass(init=False, frozen=True) class Parameters: From 6320010c5931015a00f9dfb4680aab5be8d16f38 Mon Sep 17 00:00:00 2001 From: Max Mynter Date: Tue, 26 Mar 2024 15:01:00 +0100 Subject: [PATCH 2/4] Cache with global dict --- src/nnbench/__init__.py | 2 +- src/nnbench/types/types.py | 57 ++++++++++---------------------------- 2 files changed, 15 insertions(+), 44 deletions(-) diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py index a29f67b..d35141e 100644 --- a/src/nnbench/__init__.py +++ b/src/nnbench/__init__.py @@ -11,4 +11,4 @@ from .core import benchmark, parametrize, product from .reporter import BenchmarkReporter from .runner import BenchmarkRunner -from .types import Parameters +from .types import Memo, Parameters diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py index ae2a4e0..71d25fe 100644 --- a/src/nnbench/types/types.py +++ b/src/nnbench/types/types.py @@ -13,6 +13,9 @@ T = TypeVar("T") Variable = tuple[str, type, Any] +_memo_cache: dict[int, Any] = {} +_cache_lock = threading.Lock() + def NoOp(**kwargs: Any) -> None: pass @@ -95,59 +98,27 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord: # context data. 
-class MemoCache: - _cache: dict[int, Any] = {} - _lock = threading.Lock() - - def __new__(cls) -> "MemoCache": - raise NotImplementedError( - "MemoCache is a module-level singleton utility. It cannot be instantiated." - ) - - @classmethod - def set(cls, key: int, value: Any) -> None: - with cls._lock: - cls._cache[key] = value - - @classmethod - def clear(cls) -> None: - with cls._lock: - cls._cache.clear() - - @classmethod - def drop(cls, key: int) -> None: - with cls._lock: - del cls._cache[key] - - @classmethod - def get(cls, key: int) -> Any: - with cls._lock: - if key in cls._cache: - return cls._cache[key] - else: - return None - - class Memo(Generic[T]): def __init__(self, content: Callable[..., T]) -> None: self._tid = id(self) self.content = content - def compute(self, *args: Any, **kwargs: Any) -> T: + def compute(self) -> T: raise NotImplementedError def __call__(self, *args: Any, **kwargs: Any) -> T: - val = MemoCache.get(self._tid) - if val is None: - val = self.compute(*args, **kwargs) - MemoCache.set(self._tid, val) - return val + global _memo_cache, _cache_lock + with _cache_lock: + if self._tid not in _memo_cache: + val = self.compute(*args, **kwargs) + _memo_cache[self._tid] = val + return _memo_cache[self._tid] def __del__(self) -> None: - try: - MemoCache.drop((self._tid)) - except KeyError: - pass + global _memo_cache, _cache_lock + with _cache_lock: + if self._tid in _memo_cache: + del _memo_cache[self._tid] @dataclass(init=False, frozen=True) From bde6455a7fce212ab152bc55764e35d1378a8187 Mon Sep 17 00:00:00 2001 From: Nicholas Junge Date: Tue, 26 Mar 2024 16:16:38 +0100 Subject: [PATCH 3/4] Add memo caching facility with test Also some very rough memo cache manip APIs. This means all required parts are in place for user memory management with memos. 
--- src/nnbench/types/types.py | 60 ++++++++++++++++++++++++++++---------- tests/test_memos.py | 28 ++++++++++++++++++ 2 files changed, 73 insertions(+), 15 deletions(-) create mode 100644 tests/test_memos.py diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py index 71d25fe..0d659f4 100644 --- a/src/nnbench/types/types.py +++ b/src/nnbench/types/types.py @@ -3,7 +3,9 @@ from __future__ import annotations import copy +import functools import inspect +import logging import threading from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, TypeVar @@ -16,6 +18,39 @@ _memo_cache: dict[int, Any] = {} _cache_lock = threading.Lock() +logger = logging.getLogger(__name__) + + +def memo_cache_size() -> int: + return len(_memo_cache) + + +def clear_memo_cache() -> None: + with _cache_lock: + _memo_cache.clear() + + +def evict_memo(_id: int) -> Any: + with _cache_lock: + return _memo_cache.pop(_id) + + +def cached_memo(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + _tid = id(self) + with _cache_lock: + if _tid in _memo_cache: + logger.debug(f"Returning memoized value from cache with ID {_tid}") + return _memo_cache[_tid] + logger.debug(f"Computing value on memo with ID {_tid} (cache miss)") + value = fn(self, *args, **kwargs) + with _cache_lock: + _memo_cache[_tid] = value + return value + + return wrapper + def NoOp(**kwargs: Any) -> None: pass @@ -99,26 +134,21 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord: class Memo(Generic[T]): - def __init__(self, content: Callable[..., T]) -> None: - self._tid = id(self) - self.content = content + """Abstract base class for memoized values in benchmark runs.""" - def compute(self) -> T: - raise NotImplementedError + # TODO: Make this better than the decorator application + # -> _Cached metaclass like in fsspec's AbstractFileSystem (maybe vendor with license) - def __call__(self, *args: Any, **kwargs: Any) -> T: - global _memo_cache, _cache_lock 
-        with _cache_lock:
-            if self._tid not in _memo_cache:
-                val = self.compute(*args, **kwargs)
-                _memo_cache[self._tid] = val
-            return _memo_cache[self._tid]
+    @cached_memo
+    def __call__(self) -> T:
+        raise NotImplementedError
 
     def __del__(self) -> None:
-        global _memo_cache, _cache_lock
         with _cache_lock:
-            if self._tid in _memo_cache:
-                del _memo_cache[self._tid]
+            sid = id(self)
+            if sid in _memo_cache:
+                logger.debug(f"Deleting cached value for memo with ID {sid}")
+                del _memo_cache[sid]
 
 
 @dataclass(init=False, frozen=True)
diff --git a/tests/test_memos.py b/tests/test_memos.py
new file mode 100644
index 0000000..fb96d83
--- /dev/null
+++ b/tests/test_memos.py
@@ -0,0 +1,28 @@
+from typing import Generator
+
+import pytest
+
+from nnbench.types.types import Memo, cached_memo, clear_memo_cache, memo_cache_size
+
+
+@pytest.fixture
+def clear_memos() -> Generator[None, None, None]:
+    try:
+        clear_memo_cache()
+        yield
+    finally:
+        clear_memo_cache()
+
+
+class MyMemo(Memo[int]):
+    @cached_memo
+    def __call__(self):
+        return 0
+
+
+def test_memo_caching(clear_memos):
+    m = MyMemo()
+    assert memo_cache_size() == 0
+    m()
+    assert memo_cache_size() == 1
+    m()

From bab25d337ea3765f65ad9b6e3458eca4d385219b Mon Sep 17 00:00:00 2001
From: Max Mynter
Date: Tue, 26 Mar 2024 17:17:26 +0100
Subject: [PATCH 4/4] Add docstrings

---
 src/nnbench/types/types.py | 46 ++++++++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py
index 8b74eaf..bf5e884 100644
--- a/src/nnbench/types/types.py
+++ b/src/nnbench/types/types.py
@@ -23,20 +23,58 @@
 def memo_cache_size() -> int:
+    """
+    Get the current size of the memo cache.
+
+    Returns
+    -------
+    int
+        The number of items currently stored in the memo cache.
+    """
     return len(_memo_cache)
 
 
 def clear_memo_cache() -> None:
+    """
+    Clear all items from memo cache in a thread-safe manner.
+ """ with _cache_lock: _memo_cache.clear() def evict_memo(_id: int) -> Any: + """ + Pop cached item with key `_id` from the memo cache. + + Parameters + ---------- + _id : int + The unique identifier (usually the id assigned by the Python interpreter) of the item to be evicted. + + Returns + ------- + Any + The value that was associated with the removed cache entry. If no item is found with the given `_id`, a KeyError is raised. + """ with _cache_lock: return _memo_cache.pop(_id) -def cached_memo(fn): +def cached_memo(fn: Callable) -> Callable: + """ + Decorator that caches the result of a method call based on the instance ID. + + Parameters + ---------- + fn: Callable + The method to memoize. + + Returns + ------- + Callable + A wrapped version of the method that caches its result. + """ + @functools.wraps(fn) def wrapper(self, *args, **kwargs): _tid = id(self) @@ -154,9 +192,11 @@ class Memo(Generic[T]): @cached_memo def __call__(self) -> T: + """Placeholder to override when subclassing. The call should return the to be cached object.""" raise NotImplementedError def __del__(self) -> None: + """Delete the cached object and clear it from the cache.""" with _cache_lock: sid = id(self) if sid in _memo_cache: @@ -206,9 +246,7 @@ class Benchmark: name: str = field(default="") params: dict[str, Any] = field(default_factory=dict) setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp) - tearDown: Callable[[State, Mapping[str, Any]], None] = field( - repr=False, default=NoOp - ) + tearDown: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp) tags: tuple[str, ...] = field(repr=False, default=()) interface: Interface = field(init=False, repr=False)