Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface class member to Benchmark class #20

Merged
merged 10 commits into from
Jan 24, 2024
45 changes: 44 additions & 1 deletion src/nnbench/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Useful type interfaces to override/subclass in benchmarking workflows."""
from __future__ import annotations

import inspect
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, TypedDict, TypeVar

T = TypeVar("T")
Variable = tuple[str, type]


class BenchmarkResult(TypedDict):
Expand Down Expand Up @@ -85,17 +87,58 @@ class Benchmark:
A teardown hook run after the benchmark. Must take all members of `params` as inputs.
tags: tuple[str, ...]
Additional tags to attach for bookkeeping and selective filtering during runs.
interface: Interface
Interface of the benchmark function
"""

fn: Callable[..., Any]
name: str | None = field(default=None)
setUp: Callable[..., None] = field(repr=False, default=NoOp)
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
tags: tuple[str, ...] = field(repr=False, default=())
interface: Interface = field(init=False, repr=False)

def __post_init__(self):
if not self.name:
name = self.fn.__name__

super().__setattr__("name", name)
# TODO: Parse interface using `inspect`, attach to the class
super().__setattr__("interface", Interface.from_callable(self.fn))


@dataclass(frozen=True)
class Interface:
"""
Data model representing a function's interface. An instance of this class
is created using the `from_callable` class method.

Parameters:
----------
names : tuple[str, ...]
Names of the function parameters.
types : tuple[type, ...]
Types of the function parameters.
variables : tuple[Variable, ...]
A tuple of tuples, where each inner tuple contains the parameter name and type.
defaults : dict[str, Any]
A dictionary mapping the parameters names to default values.
Only contains parameters with default values.
"""

names: tuple[str, ...]
types: tuple[type, ...]
variables: tuple[Variable, ...]
defaults: dict[str, Any]

@classmethod
def from_callable(cls, fn: Callable) -> Interface:
"""
Creates an Interface instance from a given callable.
"""
sig = inspect.signature(fn)
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple((k, v.annotation) for k, v in sig.parameters.items()),
{n: p.default for n, p in sig.parameters.items()},
)
31 changes: 31 additions & 0 deletions tests/test_benchmark_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import inspect

from nnbench import types


def test_interface_with_no_arguments():
def empty_function() -> None:
pass

interface = types.Interface.from_callable(empty_function)
assert interface.names == ()
assert interface.types == ()
assert interface.variables == ()
assert interface.defaults == {}


def test_interface_with_multiple_arguments():
def complex_function(a: int, b, c: str = "hello", d: float = 10.0) -> None: # type:ignore
pass

interface = types.Interface.from_callable(complex_function)
assert interface.names == ("a", "b", "c", "d")
assert interface.types == (
int,
inspect._empty,
str,
float,
)

assert interface.variables == (("a", int), ("b", inspect._empty), ("c", str), ("d", float))
assert interface.defaults == {"a": inspect._empty, "b": inspect._empty, "c": "hello", "d": 10.0}
Loading