Skip to content

Commit

Permalink
Hash cons Apply nodes and Constants
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 7, 2022
1 parent 09f3195 commit 80e412e
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 390 deletions.
191 changes: 106 additions & 85 deletions aesara/graph/basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Core graph classes."""
import abc
import warnings
from collections import deque
from copy import copy
Expand All @@ -26,13 +25,12 @@
Union,
cast,
)
from weakref import WeakKeyDictionary
from weakref import WeakValueDictionary

import numpy as np

from aesara.configdefaults import config
from aesara.graph.utils import (
MetaObject,
MethodNotDefined,
Scratchpad,
TestValueError,
Expand All @@ -53,32 +51,39 @@
_TypeType = TypeVar("_TypeType", bound="Type")
_IdType = TypeVar("_IdType", bound=Hashable)

T = TypeVar("T", bound="Node")
T = TypeVar("T", bound=Union["Apply", "Variable"])
NoParams = object()
NodeAndChildren = Tuple[T, Optional[Iterable[T]]]


class Node(MetaObject):
r"""A `Node` in an Aesara graph.
class UniqueInstanceFactory(type):

Currently, graphs contain two kinds of `Nodes`: `Variable`\s and `Apply`\s.
Edges in the graph are not explicitly represented. Instead each `Node`
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
__instances__: WeakValueDictionary

"""
name: Optional[str]
def __new__(cls, name, bases, dct):
dct["__instances__"] = WeakValueDictionary()
res = super().__new__(cls, name, bases, dct)
return res

def get_parents(self):
"""
Return a list of the parents of this node.
Should return a copy--i.e., modifying the return
value should not modify the graph structure.
def __call__(
cls,
*args,
**kwargs,
):
idp = cls.create_key(*args, **kwargs)

"""
raise NotImplementedError()
if idp not in cls.__instances__:
res = super(UniqueInstanceFactory, cls).__call__(*args, **kwargs)
cls.__instances__[idp] = res
return res

return cls.__instances__[idp]

class Apply(Node, Generic[OpType]):

class Apply(
Generic[OpType],
metaclass=UniqueInstanceFactory,
):
"""A `Node` representing the application of an operation to inputs.
Basically, an `Apply` instance is an object that represents the
Expand Down Expand Up @@ -113,12 +118,19 @@ class Apply(Node, Generic[OpType]):
"""

__slots__ = ("op", "inputs", "outputs", "__weakref__", "tag")

@classmethod
def create_key(cls, op, inputs, outputs):
    """Build the lookup key that identifies an `Apply` node.

    The key is a tuple holding `op` followed by every element of `inputs`,
    in order.  `outputs` is accepted but does not take part in the key.
    """
    return (op, *inputs)

def __init__(
self,
op: OpType,
inputs: Sequence["Variable"],
outputs: Sequence["Variable"],
):

if not isinstance(inputs, Sequence):
raise TypeError("The inputs of an Apply must be a sequence type")

Expand Down Expand Up @@ -154,6 +166,21 @@ def __init__(
f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
)

def __eq__(self, other):
    """Compare two nodes by their `op` and `inputs`.

    `outputs` is intentionally not part of the comparison.  Returns
    ``NotImplemented`` when `other` is not an instance of this node's
    class, so Python can fall back to the reflected comparison.
    """
    if not isinstance(other, type(self)):
        return NotImplemented

    return bool(self.op == other.op and self.inputs == other.inputs)

def __hash__(self):
    """Hash the node on its class plus the fields `__eq__` compares.

    `outputs` must stay out of the hash: `__eq__` ignores it, and objects
    that compare equal are required to have equal hashes.
    """
    return hash((type(self), self.op, tuple(self.inputs)))

def run_params(self):
"""
Returns the params for the node, or NoParams if no params is set.
Expand All @@ -165,15 +192,19 @@ def run_params(self):
return NoParams

def __getstate__(self):
d = self.__dict__
# ufunc don't pickle/unpickle well
d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)}
if hasattr(self.tag, "ufunc"):
d = copy(self.__dict__)
t = d["tag"]
del t.ufunc
d["tag"] = t
return d

def __setstate__(self, dct):
    """Restore slot attributes from the state dictionary `dct`.

    Every ``__slots__`` declaration along the instance's MRO is consulted,
    because ``self.__slots__`` alone only resolves to the most-derived
    class's declaration and would leave inherited slots unrestored on
    subclasses.  Keys in `dct` that do not name a declared slot are
    ignored, and slots missing from `dct` are left unset.
    """
    restored = set()
    for klass in type(self).__mro__:
        slots = klass.__dict__.get("__slots__", ())
        if isinstance(slots, str):
            # A bare string declares a single slot.
            slots = (slots,)
        for k in slots:
            if k in dct and k not in restored:
                restored.add(k)
                setattr(self, k, dct[k])

def default_output(self):
"""
Returns the default output for this node.
Expand Down Expand Up @@ -267,6 +298,7 @@ def clone_with_new_inputs(
from aesara.graph.op import HasInnerGraph

assert isinstance(inputs, (list, tuple))

remake_node = False
new_inputs: List["Variable"] = list(inputs)
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
Expand All @@ -280,17 +312,22 @@ def clone_with_new_inputs(
else:
remake_node = True

if remake_node:
new_op = self.op
new_op = self.op

if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
new_op = new_op.clone() # type: ignore
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
new_op = new_op.clone() # type: ignore

if remake_node:
new_node = new_op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag)
elif new_op == self.op and new_inputs == self.inputs:
new_node = self
else:
new_node = self.clone(clone_inner_graph=clone_inner_graph)
new_node.inputs = new_inputs
new_node = self.__class__(
new_op, new_inputs, [output.clone() for output in self.outputs]
)
new_node.tag = copy(self.tag)

return new_node

def get_parents(self):
Expand All @@ -316,7 +353,7 @@ def params_type(self):
return self.op.params_type


class Variable(Node, Generic[_TypeType, OptionalApplyType]):
class Variable(Generic[_TypeType, OptionalApplyType]):
r"""
A :term:`Variable` is a node in an expression graph that represents a
variable.
Expand Down Expand Up @@ -411,7 +448,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
"""

# __slots__ = ['type', 'owner', 'index', 'name']
__slots__ = ("_owner", "_index", "name", "type", "__weakref__", "tag", "auto_name")
__count__ = count(0)

_owner: OptionalApplyType
Expand Down Expand Up @@ -487,26 +524,17 @@ def __str__(self):
else:
return f"<{self.type}>"

def __repr_test_value__(self):
"""Return a ``repr`` of the test value.
Return a printable representation of the test value. It can be
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
return repr(self.get_test_value())

def __repr__(self, firstPass=True):
"""Return a ``repr`` of the `Variable`.
Return a printable name or description of the Variable. If
``config.print_test_value`` is ``True`` it will also print the test
value, if any.
Return a printable name or description of the `Variable`. If
`aesara.config.print_test_value` is ``True``, it will also print the
test value, if any.
"""
to_print = [str(self)]
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
to_print.append(repr(self.get_test_value()))
except TestValueError:
pass
return "\n".join(to_print)
Expand All @@ -528,26 +556,6 @@ def clone(self):
cp.tag = copy(self.tag)
return cp

def __lt__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __lt__", self.__class__.__name__
)

def __le__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __le__", self.__class__.__name__
)

def __gt__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __gt__", self.__class__.__name__
)

def __ge__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __ge__", self.__class__.__name__
)

def get_parents(self):
if self.owner is not None:
return [self.owner]
Expand Down Expand Up @@ -605,7 +613,7 @@ def eval(self, inputs_to_values=None):
return rval

def __getstate__(self):
d = self.__dict__.copy()
d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)}
d.pop("_fn_cache", None)
if (not config.pickle_test_value) and (hasattr(self.tag, "test_value")):
if not type(config).pickle_test_value.is_default:
Expand All @@ -618,26 +626,24 @@ def __getstate__(self):
d["tag"] = t
return d

def __setstate__(self, dct):
    """Restore slot attributes from the state dictionary `dct`.

    Every ``__slots__`` declaration along the instance's MRO is consulted,
    because ``self.__slots__`` alone only resolves to the most-derived
    class's declaration; a subclass that declares its own ``__slots__``
    would otherwise never get its inherited slots restored.  Keys in `dct`
    that do not name a declared slot are ignored, and slots missing from
    `dct` are left unset.
    """
    restored = set()
    for klass in type(self).__mro__:
        slots = klass.__dict__.get("__slots__", ())
        if isinstance(slots, str):
            # A bare string declares a single slot.
            slots = (slots,)
        for k in slots:
            if k in dct and k not in restored:
                restored.add(k)
                setattr(self, k, dct[k])


class AtomicVariable(Variable[_TypeType, None]):
"""A node type that has no ancestors and should never be considered an input to a graph."""

def __init__(self, type: _TypeType, **kwargs):
    # Forward `type` and any keyword arguments to the parent initializer,
    # passing `None` for its second and third positional parameters
    # (presumably the owner and output index -- confirm against
    # `Variable.__init__`).
    super().__init__(type, None, None, **kwargs)

@abc.abstractmethod
def signature(self):
...

def merge_signature(self):
return self.signature()

def equals(self, other):
"""
This does what `__eq__` would normally do, but `Variable` and `Apply`
should always be hashable by `id`.
"""
return isinstance(other, type(self)) and self.signature() == other.signature()
return self == other

@property
def owner(self):
Expand All @@ -661,12 +667,15 @@ def index(self, value):
class NominalVariable(AtomicVariable[_TypeType]):
"""A variable that enables alpha-equivalent comparisons."""

__instances__: WeakKeyDictionary[
__instances__: WeakValueDictionary[
Tuple["Type", Hashable], "NominalVariable"
] = WeakKeyDictionary()
] = WeakValueDictionary()

def __new__(cls, id: _IdType, typ: _TypeType, **kwargs):
if (typ, id) not in cls.__instances__:

idp = (typ, id)

if idp not in cls.__instances__:
var_type = typ.variable_type
type_name = f"Nominal{var_type.__name__}"

Expand All @@ -681,9 +690,9 @@ def _str(self):
)
res: NominalVariable = super().__new__(new_type)

cls.__instances__[(typ, id)] = res
cls.__instances__[idp] = res

return cls.__instances__[(typ, id)]
return cls.__instances__[idp]

def __init__(self, id: _IdType, typ: _TypeType, **kwargs):
self.id = id
Expand All @@ -708,11 +717,11 @@ def __hash__(self):
def __repr__(self):
return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"

def signature(self) -> Tuple[_TypeType, _IdType]:
return (self.type, self.id)


class Constant(AtomicVariable[_TypeType]):
class Constant(
AtomicVariable[_TypeType],
metaclass=UniqueInstanceFactory,
):
"""A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant
Expand All @@ -725,19 +734,22 @@ class Constant(AtomicVariable[_TypeType]):
"""

# __slots__ = ['data']
__slots__ = ("type", "data")

@classmethod
def create_key(cls, type, data, *args, **kwargs):
    """Build the lookup key for a `Constant`: its type and filtered data.

    Any additional positional or keyword arguments are accepted and
    ignored.
    """
    # TODO FIXME: `data` is filtered here and then a second time in
    # `cls.__init__`.  The duplicated work might not be a big deal, though.
    filtered = type.filter(data)
    return type, filtered

def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None):
super().__init__(type, name=name)
AtomicVariable.__init__(self, type, name=name)
self.data = type.filter(data)
add_tag_trace(self)

def get_test_value(self):
    """Return the constant's stored `data`, which serves as its test value."""
    return self.data

def signature(self):
return (self.type, self.data)

def __str__(self):
if self.name is not None:
return self.name
Expand All @@ -764,6 +776,15 @@ def owner(self, value) -> None:
def value(self):
return self.data

def __hash__(self):
    """Hash the constant on its class, its type and its data.

    NumPy arrays are unhashable, so array data is replaced by its shape,
    dtype string and raw bytes before hashing.  Any other data is hashed
    directly and must itself be hashable.
    """
    data = self.data
    if isinstance(data, np.ndarray):
        data = (data.shape, data.dtype.str, data.tobytes())
    return hash((type(self), self.type, data))

def __eq__(self, other):
    """Compare two constants by type and data.

    Returns ``NotImplemented`` when `other` is not an instance of this
    constant's class.  Array data is compared by shape, dtype and raw
    bytes, so the result is always a single ``bool`` (never an
    element-wise array) and agrees with `__hash__`.  An array never equals
    non-array data.  Non-array data is compared with ``==``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented

    if not (self.type == other.type):
        return False

    mine, theirs = self.data, other.data
    mine_is_array = isinstance(mine, np.ndarray)
    theirs_is_array = isinstance(theirs, np.ndarray)

    if mine_is_array or theirs_is_array:
        if not (mine_is_array and theirs_is_array):
            return False
        # Bitwise comparison: NaN-containing arrays equal themselves and
        # -0.0 is kept distinct from 0.0, matching the byte-based hash.
        return (
            mine.shape == theirs.shape
            and mine.dtype == theirs.dtype
            and mine.tobytes() == theirs.tobytes()
        )

    return mine == theirs


def walk(
nodes: Iterable[T],
Expand Down
Loading

0 comments on commit 80e412e

Please sign in to comment.