Skip to content

Commit

Permalink
Refactor Interface class
Browse files Browse the repository at this point in the history
  • Loading branch information
maxmynter committed Jan 24, 2024
1 parent 9b92286 commit 145aec2
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions src/nnbench/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Tuple, TypedDict, TypeVar
from typing import Any, Callable, Generic, TypedDict, TypeVar

T = TypeVar("T")

Expand Down Expand Up @@ -96,32 +96,18 @@ class Benchmark:

@dataclass(frozen=True)
class Interface:
varnames: tuple
vartypes: tuple
varitems: tuple
defaults: dict

@staticmethod
def parse_fn(
fn: Callable[..., Any],
) -> Tuple[
Tuple[str, ...],
Tuple[inspect.Parameter, ...],
Tuple[inspect.Parameter, ...],
Dict[str, Any],
]:
sig = inspect.signature(fn)
varnames = tuple(sig.parameters.keys())
vartypes = tuple(sig.parameters.values())
varitems = tuple(sig.parameters.values())
defaults = {n: p.default for n, p in sig.parameters.items()}

return varnames, vartypes, varitems, defaults
fn: Callable[..., Any]

def __post_init__(self) -> None:
sig = inspect.signature(self.fn)
super().__setattr__("varnames", tuple(sig.parameters.keys()))
super().__setattr__("vartypes", tuple(sig.parameters.values()))
super().__setattr__("varitems", tuple(sig.parameters.values()))
super().__setattr__("defaults", {n: p.default for n, p in sig.parameters.items()})

def __post_init__(self):
if not self.name:
name = self.fn.__name__

super().__setattr__("name", name)
varnames, vartypes, varitems, defaults = self.Interface.parse_fn(self.fn)
self.interface = self.Interface(varnames, vartypes, varitems, defaults)
super().__setattr__("interface", self.Interface(self.fn))

0 comments on commit 145aec2

Please sign in to comment.