Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add global memo cache and integrate with the setUp and teardown injection #130

Merged
merged 5 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter
from .runner import BenchmarkRunner
from .types import Parameters
from .types import Memo, Parameters
104 changes: 91 additions & 13 deletions src/nnbench/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import copy
import functools
import inspect
import logging
import threading
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
Expand All @@ -14,6 +16,80 @@
T = TypeVar("T")
Variable = tuple[str, type, Any]

_memo_cache: dict[int, Any] = {}
_cache_lock = threading.Lock()

logger = logging.getLogger(__name__)


def memo_cache_size() -> int:
"""
Get the current size of the memo cache.

Returns
-------
int
The number of items currently stored in the memo cache.
"""
return len(_memo_cache)


def clear_memo_cache() -> None:
"""
Clear all items from memo cache in a thread_safe manner.
"""
with _cache_lock:
_memo_cache.clear()


def evict_memo(_id: int) -> Any:
"""
Pop cached item with key `_id` from the memo cache.

Parameters
----------
_id : int
The unique identifier (usually the id assigned by the Python interpreter) of the item to be evicted.

Returns
-------
Any
The value that was associated with the removed cache entry. If no item is found with the given `_id`, a KeyError is raised.
"""
with _cache_lock:
return _memo_cache.pop(_id)


def cached_memo(fn: Callable) -> Callable:
"""
Decorator that caches the result of a method call based on the instance ID.

Parameters
----------
fn: Callable
The method to memoize.

Returns
-------
Callable
A wrapped version of the method that caches its result.
"""

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
_tid = id(self)
with _cache_lock:
if _tid in _memo_cache:
logger.debug(f"Returning memoized value from cache with ID {_tid}")
return _memo_cache[_tid]
logger.debug(f"Computing value on memo with ID {_tid} (cache miss)")
value = fn(self, *args, **kwargs)
with _cache_lock:
_memo_cache[_tid] = value
return value

return wrapper


def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
pass
Expand Down Expand Up @@ -109,22 +185,24 @@ class State:


class Memo(Generic[T]):
@functools.cache
# TODO: Swap this out for a local type-wide memo cache.
# Could also be a decorator, should look a bit like this:
# global memo_cache, memo_cache_lock
# _tid = id(self)
# val: T
# with memocache_lock:
# if _tid in memo_cache:
# val = memo_cache[_tid]
# return val
# val = self.compute()
# memo_cache[_tid] = val
# return val
"""Abstract base class for memoized values in benchmark runs."""

# TODO: Make this better than the decorator application
# -> _Cached metaclass like in fsspec's AbstractFileSystem (maybe vendor with license)

@cached_memo
def __call__(self) -> T:
"""Placeholder to override when subclassing. The call should return the to be cached object."""
raise NotImplementedError

def __del__(self) -> None:
"""Delete the cached object and clear it from the cache."""
with _cache_lock:
sid = id(self)
if sid in _memo_cache:
logger.debug(f"Deleting cached value for memo with ID {sid}")
del _memo_cache[sid]


@dataclass(init=False, frozen=True)
class Parameters:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_memos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Generator

import pytest

from nnbench.types.types import Memo, cached_memo, clear_memo_cache, memo_cache_size


@pytest.fixture
def clear_memos() -> Generator[None, None, None]:
try:
clear_memo_cache()
yield
finally:
clear_memo_cache()


class MyMemo(Memo[int]):
@cached_memo
def __call__(self):
return 0


def test_memo_caching(clear_memos):
m = MyMemo()
assert memo_cache_size() == 0
m()
assert memo_cache_size() == 1
m()