Skip to content

Commit

Permalink
Convert __init__s to only accept keyword arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
markusschmaus committed Sep 25, 2022
1 parent 03be691 commit 457ebf4
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 102 deletions.
2 changes: 2 additions & 0 deletions aesara/graph/null_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class NullType(Type):
"""

__props__ = ("why_null",)

def __init__(self, why_null="(no explanation given)"):
self.why_null = why_null

Expand Down
29 changes: 24 additions & 5 deletions aesara/graph/type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
import inspect
from abc import ABCMeta, abstractmethod
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union, final

from typing_extensions import Protocol, TypeAlias, runtime_checkable

Expand All @@ -11,14 +12,27 @@
D = TypeVar("D")


class NewTypeMeta(type):
# pass
class NewTypeMeta(ABCMeta):
__props__: tuple[str, ...]

def __call__(cls, *args, **kwargs):
raise RuntimeError("Use subtype")
# return super().__call__(*args, **kwargs)

def subtype(cls, *args, **kwargs):
return super().__call__(*args, **kwargs)
kwargs = cls.type_parameters(*args, **kwargs)
return super().__call__(**kwargs)

def type_parameters(cls, *args, **kwargs):
if args:
init_args = tuple(inspect.signature(cls.__init__).parameters.keys())[1:]
if cls.__props__[: len(args)] != init_args[: len(args)]:
raise RuntimeError(
f"{cls.__props__=} doesn't match {init_args=} for {args=}"
)

kwargs |= zip(cls.__props__, args)
return kwargs


class Type(Generic[D], metaclass=NewTypeMeta):
Expand Down Expand Up @@ -293,6 +307,11 @@ def _props_dict(self):
"""
return {a: getattr(self, a) for a in self.__props__}

@final
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def __hash__(self):
return hash((type(self), tuple(getattr(self, a) for a in self.__props__)))

Expand Down
55 changes: 32 additions & 23 deletions aesara/link/c/params_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ class ParamsType(CType):
"""

def __init__(self, **kwargs):
@classmethod
def type_parameters(cls, **kwargs):
params = dict()
if len(kwargs) == 0:
raise ValueError("Cannot create ParamsType from empty data.")

Expand All @@ -366,14 +368,14 @@ def __init__(self, **kwargs):
% (attribute_name, type_name)
)

self.length = len(kwargs)
self.fields = tuple(sorted(kwargs.keys()))
self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name()
params["length"] = len(kwargs)
params["fields"] = tuple(sorted(kwargs.keys()))
params["types"] = tuple(kwargs[field] for field in params["fields"])
params["name"] = cls.generate_struct_name(params)

self.__const_to_enum = {}
self.__alias_to_enum = {}
enum_types = [t for t in self.types if isinstance(t, EnumType)]
params["_const_to_enum"] = {}
params["_alias_to_enum"] = {}
enum_types = [t for t in params["types"] if isinstance(t, EnumType)]
if enum_types:
# We don't want same enum names in different enum types.
if sum(len(t) for t in enum_types) != len(
Expand All @@ -398,35 +400,40 @@ def __init__(self, **kwargs):
)
# We map each enum name to the enum type in which it is defined.
# We will then use this dict to find enum value when looking for enum name in ParamsType object directly.
self.__const_to_enum = {
params["_const_to_enum"] = {
enum_name: enum_type
for enum_type in enum_types
for enum_name in enum_type
}
self.__alias_to_enum = {
params["_alias_to_enum"] = {
alias: enum_type
for enum_type in enum_types
for alias in enum_type.aliases
}

return params

def __setstate__(self, state):
# NB:
# I have overridden __getattr__ to make enum constants available through
# the ParamsType when it contains enum types. To do that, I use some internal
# attributes: self.__const_to_enum and self.__alias_to_enum. These attributes
# attributes: self._const_to_enum and self._alias_to_enum. These attributes
# are normally found by Python without need to call getattr(), but when the
# ParamsType is unpickled, it seems gettatr() may be called at a point before
# __const_to_enum or __alias_to_enum are unpickled, so that gettatr() can't find
# _const_to_enum or _alias_to_enum are unpickled, so that gettatr() can't find
# those attributes, and then loop infinitely.
# For this reason, I must add this trivial implementation of __setstate__()
# to avoid errors when unpickling.
self.__dict__.update(state)

def __getattr__(self, key):
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if key in self.__const_to_enum:
return self.__const_to_enum[key][key]
return super().__getattr__(self, key)
# const_to_enum = super().__getattribute__("_const_to_enum")
if not key.startswith("__"):
const_to_enum = self._const_to_enum
if key in const_to_enum:
return const_to_enum[key][key]
raise AttributeError(f"'{self}' object has no attribute '{key}'")

def __repr__(self):
return "ParamsType<%s>" % ", ".join(
Expand All @@ -446,13 +453,14 @@ def __eq__(self, other):
def __hash__(self):
return hash((type(self),) + self.fields + self.types)

def generate_struct_name(self):
# This method tries to generate an unique name for the current instance.
@staticmethod
def generate_struct_name(params):
# This method tries to generate a unique name for the current instance.
# This name is intended to be used as struct name in C code and as constant
# definition to check if a similar ParamsType has already been created
# (see c_support_code() below).
fields_string = ",".join(self.fields).encode("utf-8")
types_string = ",".join(str(t) for t in self.types).encode("utf-8")
fields_string = ",".join(params["fields"]).encode("utf-8")
types_string = ",".join(str(t) for t in params["types"]).encode("utf-8")
fields_hex = hashlib.sha256(fields_string).hexdigest()
types_hex = hashlib.sha256(types_string).hexdigest()
return f"_Params_{fields_hex}_{types_hex}"
Expand Down Expand Up @@ -510,7 +518,7 @@ def get_enum(self, key):
print(wrapper.TWO)
"""
return self.__const_to_enum[key][key]
return self._const_to_enum[key][key]

def enum_from_alias(self, alias):
"""
Expand Down Expand Up @@ -547,10 +555,11 @@ def enum_from_alias(self, alias):
method to do that.
"""
alias_to_enum = self._alias_to_enum
return (
self.__alias_to_enum[alias].fromalias(alias)
if alias in self.__alias_to_enum
else self.__const_to_enum[alias][alias]
alias_to_enum[alias].fromalias(alias)
if alias in alias_to_enum
else self._const_to_enum[alias][alias]
)

def get_params(self, *objects, **kwargs) -> Params:
Expand Down
Loading

0 comments on commit 457ebf4

Please sign in to comment.