From 3e9665ca2705b1dd63761b455be9b7fee091f701 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 29 Aug 2022 15:04:34 -0500 Subject: [PATCH] Hash-cons Apply, Constant and change node input replacement semantics --- aesara/compile/debugmode.py | 117 +++++++------- aesara/graph/basic.py | 235 +++++++++++++++++----------- aesara/graph/destroyhandler.py | 229 +++++++++++++-------------- aesara/graph/features.py | 116 +++++++++----- aesara/graph/fg.py | 154 +++++++++++++++--- aesara/graph/rewriting/basic.py | 76 +++++---- aesara/link/c/basic.py | 4 +- aesara/link/c/params_type.py | 8 +- aesara/link/c/type.py | 10 +- aesara/sparse/basic.py | 45 +++--- aesara/tensor/basic.py | 2 + aesara/tensor/rewriting/shape.py | 108 +++++++------ aesara/tensor/type.py | 11 +- aesara/tensor/type_other.py | 49 +++--- aesara/tensor/var.py | 144 +---------------- tests/graph/rewriting/test_basic.py | 15 +- tests/graph/test_basic.py | 209 ++++++++++++++++--------- tests/graph/test_destroyhandler.py | 59 +++++-- tests/graph/test_fg.py | 208 +++++++++++++++++------- 19 files changed, 1034 insertions(+), 765 deletions(-) diff --git a/aesara/compile/debugmode.py b/aesara/compile/debugmode.py index c3c116b091..41e45d2eec 100644 --- a/aesara/compile/debugmode.py +++ b/aesara/compile/debugmode.py @@ -14,10 +14,11 @@ from itertools import chain from itertools import product as itertools_product from logging import Logger -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union from warnings import warn import numpy as np +from typing_extensions import Literal import aesara from aesara.compile.function.types import ( @@ -42,7 +43,9 @@ from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function -__docformat__ = "restructuredtext en" +if TYPE_CHECKING: + from aesara.graph.basic import Apply + _logger: Logger = logging.getLogger("aesara.compile.debugmode") _logger.addFilter(NoDuplicateOptWarningFilter()) @@ -1109,43 +1112,32 @@ class _FunctionGraphEvent: """ - kind = "" - """ - One of 'import', 'change', 'prune'. - - """ - - node = None - """ - Either 'output' or an Apply instance. - - """ - - op = None - """Either 'output' or an Op instance""" + kind: Literal["import", "change", "prune"] + old_node: Optional[Union[Literal["output"], "Apply"]] + new_node: Optional[Union[Literal["output"], "Apply"]] + op: Optional[Union[Literal["output"], Op]] + idx: Optional[int] + reason: Optional[str] - idx = None - """ - Change events involve an position index of the input variable. - - """ - - reason = None - """ - Change events sometimes have a reason. - - """ - - def __init__(self, kind, node, idx=None, reason=None): + def __init__( + self, + kind: Literal["import", "change", "prune"], + old_node: Union[Literal["output"], "Apply"], + new_node: Union[Literal["output"], "Apply"] = None, + idx: Optional[int] = None, + reason: Optional[str] = None, + ): self.kind = kind - if node == "output": - self.node = "output" + if old_node == "output": + self.old_node = "output" + self.new_node = "output" self.op = "output" else: - self.node = node - self.op = node.op + self.old_node = old_node + self.new_node = new_node + self.op = old_node.op self.idx = idx - self.reason = str(reason) + self.reason = str(reason) if reason else None def __str__(self): if self.kind == "change": @@ -1219,21 +1211,21 @@ def on_attach(self, fgraph): self.replaced_by = {} self.event_list = [] for node in fgraph.toposort(): - self.on_import(fgraph, node, "on_attach") + self.on_import(fgraph, node, reason="on_attach") def on_detach(self, fgraph): assert fgraph is self.fgraph self.fgraph = None def on_prune(self, fgraph, node, reason): - self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason))) + self.event_list.append(_FunctionGraphEvent("prune", node, reason=reason)) assert node in self.active_nodes assert node not in self.inactive_nodes self.active_nodes.remove(node) self.inactive_nodes.add(node) def on_import(self, fgraph, node, reason): - self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason))) + self.event_list.append(_FunctionGraphEvent("import", node, reason=reason)) assert node not in self.active_nodes self.active_nodes.add(node) @@ -1253,18 +1245,23 @@ def on_import(self, fgraph, node, reason): self.reasons.setdefault(r, []) self.replaced_by.setdefault(r, []) - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): reason = str(reason) self.event_list.append( - _FunctionGraphEvent("change", node, reason=reason, idx=i) + _FunctionGraphEvent("change", old_node, new_node, idx=i, reason=reason) ) - self.reasons.setdefault(new_r, []) - self.replaced_by.setdefault(new_r, []) + self.on_import(fgraph, new_node, reason=reason) + self.on_prune(fgraph, old_node, reason=reason) + + self.reasons.setdefault(new_var, []) + self.replaced_by.setdefault(new_var, []) append_reason = True - for tup in self.reasons[new_r]: - if tup[0] == reason and tup[1] is r: + for tup in self.reasons[new_var]: + if tup[0] == reason and tup[1] is old_var: append_reason = False if append_reason: @@ -1272,12 +1269,12 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): # optimizations will change the graph done = dict() used_ids = dict() - self.reasons[new_r].append( + self.reasons[new_var].append( ( reason, - r, + old_var, _debugprint( - r, + old_var, prefix=" ", depth=6, file=StringIO(), @@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): used_ids=used_ids, ).getvalue(), _debugprint( - new_r, + new_var, prefix=" ", depth=6, file=StringIO(), @@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): ).getvalue(), ) ) - self.replaced_by[r].append((reason, new_r)) + self.replaced_by[old_var].append((reason, new_var)) - if r in self.equiv: - r_set = self.equiv[r] + if old_var in self.equiv: + r_set = self.equiv[old_var] else: - r_set = self.equiv.setdefault(r, {r}) - self.all_variables_ever.append(r) + r_set = self.equiv.setdefault(old_var, {old_var}) + self.all_variables_ever.append(old_var) - if new_r in self.equiv: - new_r_set = self.equiv[new_r] + if new_var in self.equiv: + new_r_set = self.equiv[new_var] else: - new_r_set = self.equiv.setdefault(new_r, {new_r}) - self.all_variables_ever.append(new_r) + new_r_set = self.equiv.setdefault(new_var, {new_var}) + self.all_variables_ever.append(new_var) - assert new_r in new_r_set - assert r in r_set + assert new_var in new_r_set + assert old_var in r_set # update one equivalence set to contain the other # transfer all the elements of the old one to the new one @@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): self.equiv[like_new_r] = r_set assert like_new_r in r_set - assert self.equiv[r] is r_set - assert self.equiv[new_r] is r_set + assert self.equiv[old_var] is r_set + assert self.equiv[new_var] is r_set def printstuff(self): for key in self.equiv: diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index 89f3913fb1..f02a7b1141 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -1,5 +1,4 @@ """Core graph classes.""" -import abc import warnings from collections import deque from copy import copy @@ -32,7 +31,6 @@ from aesara.configdefaults import config from aesara.graph.utils import ( - MetaObject, MethodNotDefined, Scratchpad, TestValueError, @@ -53,32 +51,48 @@ _TypeType = TypeVar("_TypeType", bound="Type") _IdType = TypeVar("_IdType", bound=Hashable) -T = TypeVar("T", bound="Node") +T = TypeVar("T", bound=Union["Apply", "Variable"]) NoParams = object() NodeAndChildren = Tuple[T, Optional[Iterable[T]]] -class Node(MetaObject): - r"""A `Node` in an Aesara graph. +class UniqueInstanceFactory(type): - Currently, graphs contain two kinds of `Nodes`: `Variable`\s and `Apply`\s. - Edges in the graph are not explicitly represented. Instead each `Node` - keeps track of its parents via `Variable.owner` / `Apply.inputs`. + __instances__: WeakValueDictionary - """ - name: Optional[str] + def __new__(cls, name, bases, dct): + dct["__instances__"] = WeakValueDictionary() - def get_parents(self): - """ - Return a list of the parents of this node. - Should return a copy--i.e., modifying the return - value should not modify the graph structure. + if "_post_call" not in dct: - """ - raise NotImplementedError() + def _post_call(self, *args, **kwargs): + return self + + dct["_post_call"] = _post_call + + res = super().__new__(cls, name, bases, dct) + return res + + def __call__( + cls, + *args, + **kwargs, + ): + idp = cls.create_key(*args, **kwargs) + res = cls.__instances__.get(idp) -class Apply(Node, Generic[OpType]): + if res is None: + res = super(UniqueInstanceFactory, cls).__call__(*args, **kwargs) + cls.__instances__[idp] = res + + return res._post_call(*args, **kwargs) + + +class Apply( + Generic[OpType], + metaclass=UniqueInstanceFactory, +): """A `Node` representing the application of an operation to inputs. Basically, an `Apply` instance is an object that represents the @@ -113,12 +127,38 @@ class Apply(Node, Generic[OpType]): """ + __slots__ = ("op", "inputs", "outputs", "__weakref__", "tag") + + @classmethod + def create_key(cls, op, inputs, outputs): + return (op,) + tuple(inputs) + def __init__( self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"], ): + r""" + + Parameters + ---------- + op + The operation that produces `outputs` given `inputs`. + inputs + The arguments of the expression modeled by the `Apply` node. + outputs + The outputs of the expression modeled by the `Apply` node. If a + node already exists for the given `op` and `inputs` combination, + each `Variable` in `outputs` will be associated with the node + (i.e. `Variable.owner` will be (re)set), and the `Apply.outputs` + values for the returned node will consist of the original outputs + and not the new `outputs`. + In other words, `Apply.outputs` is always a consistent, unique list + of `Variable`\s for each `op` and `inputs` pair. + + """ + if not isinstance(inputs, Sequence): raise TypeError("The inputs of an Apply must be a sequence type") @@ -129,7 +169,6 @@ def __init__( self.inputs: List[Variable] = [] self.tag = Scratchpad() - # filter inputs to make sure each element is a Variable for input in inputs: if isinstance(input, Variable): self.inputs.append(input) @@ -137,22 +176,34 @@ def __init__( raise TypeError( f"The 'inputs' argument to Apply must contain Variable instances, not {input}" ) - self.outputs: List[Variable] = [] - # filter outputs to make sure each element is a Variable + + self.outputs: List[Variable] = list(outputs) + + def _post_call(self, op, inputs, outputs): + + # If a user passes new outputs to an existing `Apply` node, those + # outputs will be updated and associated with the node, but the + # returned node's outputs will still be the original `Variable`s. for i, output in enumerate(outputs): - if isinstance(output, Variable): - if output.owner is None: - output.owner = self - output.index = i - elif output.owner is not self or output.index != i: - raise ValueError( - "All output variables passed to Apply must belong to it." - ) - self.outputs.append(output) - else: + if not isinstance(output, Variable): raise TypeError( f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}" ) + output.owner = self + output.index = i + + return self + + def __eq__(self, other): + if isinstance(other, type(self)): + if self.op == other.op and self.inputs == other.inputs: + return True + return False + + return NotImplemented + + def __hash__(self): + return hash((type(self), self.op, tuple(self.inputs))) def run_params(self): """ @@ -165,8 +216,7 @@ def run_params(self): return NoParams def __getstate__(self): - d = self.__dict__ - # ufunc don't pickle/unpickle well + d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)} if hasattr(self.tag, "ufunc"): d = copy(self.__dict__) t = d["tag"] @@ -174,6 +224,11 @@ def __getstate__(self): d["tag"] = t return d + def __setstate__(self, dct): + for k in self.__slots__: + if k in dct: + setattr(self, k, dct[k]) + def default_output(self): """ Returns the default output for this node. @@ -267,6 +322,7 @@ def clone_with_new_inputs( from aesara.graph.op import HasInnerGraph assert isinstance(inputs, (list, tuple)) + remake_node = False new_inputs: List["Variable"] = list(inputs) for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): @@ -280,17 +336,22 @@ def clone_with_new_inputs( else: remake_node = True - if remake_node: - new_op = self.op + new_op = self.op - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore + if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore + new_op = new_op.clone() # type: ignore + if remake_node: new_node = new_op.make_node(*new_inputs) new_node.tag = copy(self.tag).__update__(new_node.tag) + elif new_op == self.op and new_inputs == self.inputs: + new_node = self else: - new_node = self.clone(clone_inner_graph=clone_inner_graph) - new_node.inputs = new_inputs + new_node = self.__class__( + new_op, new_inputs, [output.clone() for output in self.outputs] + ) + new_node.tag = copy(self.tag) + return new_node def get_parents(self): @@ -316,7 +377,7 @@ def params_type(self): return self.op.params_type -class Variable(Node, Generic[_TypeType, OptionalApplyType]): +class Variable(Generic[_TypeType, OptionalApplyType]): r""" A :term:`Variable` is a node in an expression graph that represents a variable. @@ -411,7 +472,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): """ - # __slots__ = ['type', 'owner', 'index', 'name'] + __slots__ = ("_owner", "_index", "name", "type", "__weakref__", "tag", "auto_name") __count__ = count(0) _owner: OptionalApplyType @@ -487,26 +548,17 @@ def __str__(self): else: return f"<{self.type}>" - def __repr_test_value__(self): - """Return a ``repr`` of the test value. - - Return a printable representation of the test value. It can be - overridden by classes with non printable test_value to provide a - suitable representation of the test_value. - """ - return repr(self.get_test_value()) - def __repr__(self, firstPass=True): """Return a ``repr`` of the `Variable`. - Return a printable name or description of the Variable. If - ``config.print_test_value`` is ``True`` it will also print the test - value, if any. + Return a printable name or description of the `Variable`. If + `aesara.config.print_test_value` is ``True``, it will also print the + test value, if any. """ to_print = [str(self)] if config.print_test_value and firstPass: try: - to_print.append(self.__repr_test_value__()) + to_print.append(repr(self.get_test_value())) except TestValueError: pass return "\n".join(to_print) @@ -534,26 +586,6 @@ def clone(self, **kwargs): cp.tag = copy(self.tag) return cp - def __lt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __lt__", self.__class__.__name__ - ) - - def __le__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __le__", self.__class__.__name__ - ) - - def __gt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __gt__", self.__class__.__name__ - ) - - def __ge__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __ge__", self.__class__.__name__ - ) - def get_parents(self): if self.owner is not None: return [self.owner] @@ -611,7 +643,7 @@ def eval(self, inputs_to_values=None): return rval def __getstate__(self): - d = self.__dict__.copy() + d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)} d.pop("_fn_cache", None) if (not config.pickle_test_value) and (hasattr(self.tag, "test_value")): if not type(config).pickle_test_value.is_default: @@ -624,6 +656,11 @@ def __getstate__(self): d["tag"] = t return d + def __setstate__(self, dct): + for k in self.__slots__: + if k in dct: + setattr(self, k, dct[k]) + class AtomicVariable(Variable[_TypeType, None]): """A node type that has no ancestors and should never be considered an input to a graph.""" @@ -631,19 +668,12 @@ class AtomicVariable(Variable[_TypeType, None]): def __init__(self, type: _TypeType, name: Optional[str] = None, **kwargs): super().__init__(type=type, owner=None, index=None, name=name, **kwargs) - @abc.abstractmethod - def signature(self): - ... - - def merge_signature(self): - return self.signature() - def equals(self, other): """ This does what `__eq__` would normally do, but `Variable` and `Apply` should always be hashable by `id`. """ - return isinstance(other, type(self)) and self.signature() == other.signature() + return self == other @property def owner(self): @@ -677,7 +707,10 @@ class NominalVariable(AtomicVariable[_TypeType]): __instances__: WeakValueDictionary = WeakValueDictionary() def __new__(cls, id: _IdType, typ: _TypeType, **kwargs): - if (typ, id) not in cls.__instances__: + + idp = (typ, id) + + if idp not in cls.__instances__: var_type = typ.variable_type type_name = f"Nominal{var_type.__name__}" @@ -692,9 +725,9 @@ def _str(self): ) res: NominalVariable = super().__new__(new_type) - cls.__instances__[(typ, id)] = res + cls.__instances__[idp] = res - return cls.__instances__[(typ, id)] + return cls.__instances__[idp] def __init__(self, id: _IdType, typ: _TypeType, name: Optional[str] = None): self.id = id @@ -720,11 +753,11 @@ def __hash__(self): def __repr__(self): return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})" - def signature(self) -> Tuple[_TypeType, _IdType]: - return (self.type, self.id) - -class Constant(AtomicVariable[_TypeType]): +class Constant( + AtomicVariable[_TypeType], + metaclass=UniqueInstanceFactory, +): """A `Variable` with a fixed `data` field. `Constant` nodes make numerous optimizations possible (e.g. constant @@ -737,19 +770,22 @@ class Constant(AtomicVariable[_TypeType]): """ - # __slots__ = ['data'] + __slots__ = ("type", "data") + + @classmethod + def create_key(cls, type, data, *args, **kwargs): + # TODO FIXME: This filters the data twice: once here, and again in + # `cls.__init__`. This might not be a big deal, though. + return (type, type.filter(data)) def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None): - super().__init__(type, name=name) + AtomicVariable.__init__(self, type, name=name) self.data = type.filter(data) add_tag_trace(self) def get_test_value(self): return self.data - def signature(self): - return (self.type, self.data) - def __str__(self): if self.name is not None: return self.name @@ -775,6 +811,15 @@ def owner(self, value) -> None: def value(self): return self.data + def __hash__(self): + return hash((type(self), self.type, self.data)) + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.type == other.type and self.data == other.data + + return NotImplemented + def walk( nodes: Iterable[T], diff --git a/aesara/graph/destroyhandler.py b/aesara/graph/destroyhandler.py index abc4894715..1f6c4956ce 100644 --- a/aesara/graph/destroyhandler.py +++ b/aesara/graph/destroyhandler.py @@ -14,15 +14,6 @@ from aesara.misc.ordered_set import OrderedSet -class ProtocolError(Exception): - """ - Raised when FunctionGraph calls DestroyHandler callbacks in - an invalid way, for example, pruning or changing a node that has - never been imported. - - """ - - def _contains_cycle(fgraph, orderings): """ Function to check if the given graph contains a cycle @@ -180,7 +171,8 @@ def _build_droot_impact(destroy_handler): impact = {} # destroyed nonview variable -> it + all views of it root_destroyer = {} # root -> destroyer apply - for app in destroy_handler.destroyers: + for ref_out in destroy_handler.destroyers: + app = ref_out.owner for output_idx, input_idx_list in app.op.destroy_map.items(): if len(input_idx_list) != 1: raise NotImplementedError() @@ -250,7 +242,7 @@ def fast_inplace_check(fgraph, inputs): return inputs -class DestroyHandler(Bookkeeper): # noqa +class DestroyHandler(Bookkeeper): """ The DestroyHandler class detects when a graph is impossible to evaluate because of aliasing and destructive operations. @@ -319,8 +311,8 @@ def __init__(self, do_imports_on_attach=True, algo=None): self.impact = OrderedDict() """ - If a var is destroyed, then this dict will map - droot[var] to the apply node that destroyed var + If a ``var`` is destroyed, then this dict will map + ``droot[var]`` to the `Variable` that's owner destroyed ``var`` TODO: rename to vroot_to_destroyer """ @@ -334,21 +326,6 @@ def clone(self): return type(self)(self.do_imports_on_attach, self.algo) def on_attach(self, fgraph): - """ - When attaching to a new fgraph, check that - 1) This DestroyHandler wasn't already attached to some fgraph - (its data structures are only set up to serve one). - 2) The FunctionGraph doesn't already have a DestroyHandler. - This would result in it validating everything twice, causing - compilation to be slower. - - Give the FunctionGraph instance: - 1) A new method "destroyers(var)" - TODO: what does this do exactly? - 2) A new attribute, "destroy_handler" - TODO: WRITEME: what does this do besides the checks? - - """ if any(hasattr(fgraph, attr) for attr in ("destroyers", "destroy_handler")): raise AlreadyThere("DestroyHandler feature is already present") @@ -358,20 +335,28 @@ def on_attach(self, fgraph): "A DestroyHandler instance can only serve one FunctionGraph" ) - # Annotate the FunctionGraph # self.unpickle(fgraph) + fgraph.destroy_handler = self self.fgraph = fgraph - self.destroyers = ( - OrderedSet() - ) # set of Apply instances with non-null destroy_map self.view_i = {} # variable -> variable used in calculation self.view_o = ( {} ) # variable -> set of variables that use this one as a direct input - # clients: how many times does an apply use a given variable - self.clients = OrderedDict() # variable -> apply -> ninputs + + # The following map tracks how many times a variable is referenced by an `Apply` node. + # It doesn't actually use `Apply` nodes, though, because doing so would require + # that we update the `Apply` nodes on every replacement in the graph. Instead, we use + # the first output `Variable` of an `Apply` node. Since `Variable.owner` is updated + # whenever a replacement is made, these representative output `Variable`s will always + # point to the appropriate `Apply` node. + self.clients = OrderedDict() + + # Set of output `Variable`s representing `Apply` nodes (see the + # description for `self.clients`) with non-null `Op.destroy_map`s. + self.destroyers = OrderedSet() + self.stale_droot = True self.debug_all_apps = set() @@ -497,72 +482,75 @@ def fast_destroy(self, fgraph, app, reason): # assert len(v) <= 1 # assert len(d) <= 1 - def on_import(self, fgraph, app, reason): + def on_import(self, fgraph, node, reason): """ Add Apply instance to set which must be computed. """ - if app in self.debug_all_apps: - raise ProtocolError("double import") - self.debug_all_apps.add(app) - # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) + # Choose an output to represent the `Apply` node + rep_out = node.outputs[0] + + if rep_out in self.debug_all_apps: + return + + self.debug_all_apps.add(rep_out) # If it's a destructive op, add it to our watch list - dmap = app.op.destroy_map - vmap = app.op.view_map + dmap = node.op.destroy_map + vmap = node.op.view_map if dmap: - self.destroyers.add(app) + self.destroyers.add(rep_out) if self.algo == "fast": - self.fast_destroy(fgraph, app, reason) + self.fast_destroy(fgraph, node, reason) # add this symbol to the forward and backward maps for o_idx, i_idx_list in vmap.items(): if len(i_idx_list) > 1: raise NotImplementedError( - "destroying this output invalidates multiple inputs", (app.op) + "destroying this output invalidates multiple inputs", (node.op) ) - o = app.outputs[o_idx] - i = app.inputs[i_idx_list[0]] + o = node.outputs[o_idx] + i = node.inputs[i_idx_list[0]] self.view_i[o] = i self.view_o.setdefault(i, OrderedSet()).add(o) - # update self.clients - for i, input in enumerate(app.inputs): - self.clients.setdefault(input, OrderedDict()).setdefault(app, 0) - self.clients[input][app] += 1 + for i, input in enumerate(node.inputs): + self.clients.setdefault(input, OrderedDict()).setdefault(rep_out, 0) + self.clients[input][rep_out] += 1 - for i, output in enumerate(app.outputs): + for i, output in enumerate(node.outputs): self.clients.setdefault(output, OrderedDict()) self.stale_droot = True - def on_prune(self, fgraph, app, reason): + def on_prune(self, fgraph, node, reason): """ Remove Apply instance from set which must be computed. """ - if app not in self.debug_all_apps: - raise ProtocolError("prune without import") - self.debug_all_apps.remove(app) + # Choose an output to represent the `Apply` node + rep_out = node.outputs[0] + + assert rep_out in self.debug_all_apps + + self.debug_all_apps.remove(rep_out) - # UPDATE self.clients - for input in set(app.inputs): - del self.clients[input][app] + for input in set(node.inputs): + del self.clients[input][rep_out] - if app.op.destroy_map: - self.destroyers.remove(app) + if node.op.destroy_map: + self.destroyers.remove(rep_out) # Note: leaving empty client dictionaries in the struct. # Why? It's a pain to remove them. I think they aren't doing any harm, they will be # deleted on_detach(). - # UPDATE self.view_i, self.view_o - for o_idx, i_idx_list in app.op.view_map.items(): + for o_idx, i_idx_list in node.op.view_map.items(): if len(i_idx_list) > 1: # destroying this output invalidates multiple inputs raise NotImplementedError() - o = app.outputs[o_idx] - i = app.inputs[i_idx_list[0]] + o = node.outputs[o_idx] + i = node.inputs[i_idx_list[0]] del self.view_i[o] @@ -571,53 +559,61 @@ def on_prune(self, fgraph, app, reason): del self.view_o[i] self.stale_droot = True - if app in self.fail_validate: - del self.fail_validate[app] + if rep_out in self.fail_validate: + del self.fail_validate[rep_out] - def on_change_input(self, fgraph, app, i, old_r, new_r, reason): - """ - app.inputs[i] changed from old_r to new_r. + def on_change_input( + self, fgraph, old_node, new_node, input_idx, old_var, new_var, reason + ): + """Update the clients and view mappings.""" - """ - if app == "output": + if old_node != "output": # app == 'output' is special key that means FunctionGraph is redefining which nodes are being # considered 'outputs' of the graph. - pass - else: - if app not in self.debug_all_apps: - raise ProtocolError("change without import") - - # UPDATE self.clients - self.clients[old_r][app] -= 1 - if self.clients[old_r][app] == 0: - del self.clients[old_r][app] - - self.clients.setdefault(new_r, OrderedDict()).setdefault(app, 0) - self.clients[new_r][app] += 1 - - # UPDATE self.view_i, self.view_o - for o_idx, i_idx_list in app.op.view_map.items(): - if len(i_idx_list) > 1: - # destroying this output invalidates multiple inputs - raise NotImplementedError() + + # Use the first output to represent the `Apply` node. + # N.B. The old node's outputs should be the same as the new node's + # outputs. + rep_out = old_node.outputs[0] + + assert rep_out in self.debug_all_apps + + new_count = self.clients[old_var][rep_out] - 1 + + assert new_count >= 0 + + if new_count == 0: + del self.clients[old_var][rep_out] + else: + self.clients[old_var][rep_out] = new_count + + self.clients.setdefault(new_var, OrderedDict()).setdefault(rep_out, 0) + self.clients[new_var][rep_out] += 1 + + for o_idx, i_idx_list in new_node.op.view_map.items(): + + # Destroying this output would invalidate multiple inputs, and + # that's not currently supported + assert len(i_idx_list) == 1 + i_idx = i_idx_list[0] - output = app.outputs[o_idx] - if i_idx == i: - if app.inputs[i_idx] is not new_r: - raise ProtocolError("wrong new_r on change") + output = new_node.outputs[o_idx] + if i_idx == input_idx: + assert new_node.inputs[i_idx] is new_var - self.view_i[output] = new_r + self.view_i[output] = new_var - self.view_o[old_r].remove(output) - if not self.view_o[old_r]: - del self.view_o[old_r] + self.view_o[old_var].remove(output) + if not self.view_o[old_var]: + del self.view_o[old_var] - self.view_o.setdefault(new_r, OrderedSet()).add(output) + self.view_o.setdefault(new_var, OrderedSet()).add(output) if self.algo == "fast": - if app in self.fail_validate: - del self.fail_validate[app] - self.fast_destroy(fgraph, app, reason) + if rep_out in self.fail_validate: + del self.fail_validate[rep_out] + self.fast_destroy(fgraph, old_node, reason) + self.stale_droot = True def validate(self, fgraph): @@ -632,7 +628,7 @@ def validate(self, fgraph): if self.destroyers: if self.algo == "fast": if self.fail_validate: - app_err_pairs = self.fail_validate + rep_out_err_pairs = self.fail_validate self.fail_validate = OrderedDict() # self.fail_validate can only be a hint that maybe/probably # there is a cycle.This is because inside replace() we could @@ -641,12 +637,14 @@ def validate(self, fgraph): # graph might have already changed when we raise the # self.fail_validate error. So before raising the error, we # double check here. - for app in app_err_pairs: + for rep_out in rep_out_err_pairs: + app = rep_out.owner if app in fgraph.apply_nodes: self.fast_destroy(fgraph, app, "validate") + if self.fail_validate: - self.fail_validate = app_err_pairs - raise app_err_pairs[app] + self.fail_validate = rep_out_err_pairs + raise rep_out_err_pairs[rep_out] else: ords = self.orderings(fgraph, ordered=False) if _contains_cycle(fgraph, ords): @@ -700,13 +698,14 @@ def orderings(self, fgraph, ordered=True): ) # add destroyed variable clients as computational dependencies - for app in self.destroyers: + for rep_out in self.destroyers: + destroyer_node = rep_out.owner # keep track of clients that should run before the current Apply root_clients = set_type() # for each destroyed input... - for output_idx, input_idx_list in app.op.destroy_map.items(): + for output_idx, input_idx_list in destroyer_node.op.destroy_map.items(): destroyed_idx = input_idx_list[0] - destroyed_variable = app.inputs[destroyed_idx] + destroyed_variable = destroyer_node.inputs[destroyed_idx] root = droot[destroyed_variable] root_impact = impact[root] # we generally want to put all clients of things which depend on root @@ -744,27 +743,29 @@ def orderings(self, fgraph, ordered=True): # CHECK FOR INPUT ALIASING # OPT: pre-compute this on import - tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) + tolerate_same = getattr( + destroyer_node.op, "destroyhandler_tolerate_same", [] + ) assert isinstance(tolerate_same, list) tolerated = { idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx } tolerated.add(destroyed_idx) tolerate_aliased = getattr( - app.op, "destroyhandler_tolerate_aliased", [] + destroyer_node.op, "destroyhandler_tolerate_aliased", [] ) assert isinstance(tolerate_aliased, list) ignored = { idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx } - for i, input in enumerate(app.inputs): + for i, input in enumerate(destroyer_node.inputs): if i in ignored: continue if input in root_impact and ( i not in tolerated or input is not destroyed_variable ): raise InconsistencyError( - f"Input aliasing: {app} ({destroyed_idx}, {i})" + f"Input aliasing: {destroyer_node} ({destroyed_idx}, {i})" ) # add the rule: app must be preceded by all other Apply instances that @@ -777,8 +778,8 @@ def orderings(self, fgraph, ordered=True): # app itself is a client of the destroyed inputs, # but should not run before itself - root_clients.remove(app) + root_clients.remove(rep_out) if root_clients: - rval[app] = root_clients + rval[destroyer_node] = root_clients return rval diff --git a/aesara/graph/features.py b/aesara/graph/features.py index 73a625409f..af33468438 100644 --- a/aesara/graph/features.py +++ b/aesara/graph/features.py @@ -5,6 +5,7 @@ from collections import OrderedDict from functools import partial from io import StringIO +from typing import TYPE_CHECKING, Mapping, Optional, Sequence import numpy as np @@ -14,6 +15,11 @@ from aesara.graph.utils import InconsistencyError +if TYPE_CHECKING: + from aesara.graph.basic import Apply + from aesara.graph.fg import FunctionGraph + + class AlreadyThere(Exception): """ Raised by a Feature's on_attach callback method if the FunctionGraph @@ -262,31 +268,31 @@ class Feature: """ - def on_attach(self, fgraph): - """ + def on_attach(self, fgraph) -> None: + """Handle the association of an `FunctionGraph` with this `Feature`. + Called by `FunctionGraph.attach_feature`, the method that attaches the feature to the `FunctionGraph`. Since this is called after the `FunctionGraph` is initially populated, this is where you should run checks on the initial contents of the `FunctionGraph`. - The on_attach method may raise the `AlreadyThere` exception to cancel - the attach operation if it detects that another Feature instance - implementing the same functionality is already attached to the + This method may raise an `AlreadyThere` exception to cancel the + attachment operation, e.g. if it detects that another `Feature` + instance implementing the same functionality is already attached to the `FunctionGraph`. - The feature has great freedom in what it can do with the `fgraph`: it - may, for example, add methods to it dynamically. - """ - def on_detach(self, fgraph): + def on_detach(self, fgraph: "FunctionGraph") -> None: """ Called by `FunctionGraph.remove_feature`. Should remove any dynamically-added functionality that it installed into the fgraph. """ - def on_import(self, fgraph, node, reason): + def on_import( + self, fgraph: "FunctionGraph", node: "Apply", reason: Optional[str] = None + ) -> None: """ Called whenever a node is imported into `fgraph`, which is just before the node is actually connected to the graph. @@ -297,36 +303,55 @@ def on_import(self, fgraph, node, reason): """ - def on_change_input(self, fgraph, node, i, var, new_var, reason=None): - """ - Called whenever ``node.inputs[i]`` is changed from `var` to `new_var`. - At the moment the callback is done, the change has already taken place. + def on_change_input( + self, + fgraph: "FunctionGraph", + old_node: "Apply", + new_node: "Apply", + i: int, + old_var: Variable, + new_var: Variable, + reason: Optional[str] = None, + ) -> None: + """Handle node and input replacements. - If you raise an exception in this function, the state of the graph - might be broken for all intents and purposes. + This is called whenever ``node.inputs[i]`` is changed from `old_var` to + `new_var`, and, since `Apply` nodes represent a distinct set of inputs, + a new node is created to replace the old one. - """ + When this method is called, the change has already been made. + + Warning: If an exception is raised in this function, the state of the + graph could become invalid. - def on_prune(self, fgraph, node, reason): """ + + def on_prune( + self, fgraph: "FunctionGraph", node: "Apply", reason: Optional[str] = None + ) -> None: + """Handle removal of an `Apply` node. + Called whenever a node is pruned (removed) from the `fgraph`, after it is disconnected from the graph. """ - def orderings(self, fgraph): - """ - Called by `FunctionGraph.toposort`. It should return a dictionary of + def orderings(self, fgraph: "FunctionGraph") -> Mapping["Apply", Sequence["Apply"]]: + """Return a dictionary mapping nodes to their predecessors. + + It should return a dictionary of ``{node: predecessors}`` where ``predecessors`` is a list of nodes that should be computed before the key node. - If you raise an exception in this function, the state of the graph - might be broken for all intents and purposes. + This is called by `FunctionGraph.toposort`. + + Warning: If an exception is raised in this function, the state of the + graph could become invalid. """ return OrderedDict() - def clone(self): + def clone(self) -> "Feature": """Create a clone that can be attached to a new `FunctionGraph`. This default implementation returns `self`, which carries the @@ -361,16 +386,23 @@ def __call__(self): class LambdaExtract: - def __init__(self, fgraph, node, i, r, reason=None): + """A class that represents `change_node_input` calls.""" + + def __init__(self, fgraph, old_node, new_node, i, old_var, reason=None): self.fgraph = fgraph - self.node = node + self.old_node = old_node + self.new_node = new_node self.i = i - self.r = r + self.old_var = old_var self.reason = reason def __call__(self): return self.fgraph.change_node_input( - self.node, self.i, self.r, reason=("Revert", self.reason), check=False + self.new_node, + self.i, + self.old_var, + reason=f"Revert: {self.reason}", + check=False, ) @@ -417,11 +449,13 @@ def on_detach(self, fgraph): del fgraph.revert del self.history[fgraph] - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): if self.history[fgraph] is None: return h = self.history[fgraph] - h.append(LambdaExtract(fgraph, node, i, r, reason)) + h.append(LambdaExtract(fgraph, old_node, new_node, i, old_var, reason)) def revert(self, fgraph, checkpoint): """ @@ -742,9 +776,13 @@ def on_prune(self, fgraph, node, reason): if self.active: print(f"-- pruning: {node}, reason: {reason}") - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): if self.active: - print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") + print( + f"-- changing {old_node}.inputs[{i}] from {old_var} to {new_var} resulting in {new_node}" + ) class PreserveVariableAttributes(Feature): @@ -752,14 +790,16 @@ class PreserveVariableAttributes(Feature): This preserve some variables attributes and tag during optimization. """ - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): - if r.name is not None and new_r.name is None: - new_r.name = r.name + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): + if old_var.name is not None and new_var.name is None: + new_var.name = old_var.name if ( - getattr(r.tag, "nan_guard_mode_check", False) - and getattr(new_r.tag, "nan_guard_mode_check", False) is False + getattr(old_var.tag, "nan_guard_mode_check", False) + and getattr(new_var.tag, "nan_guard_mode_check", False) is False ): - new_r.tag.nan_guard_mode_check = r.tag.nan_guard_mode_check + new_var.tag.nan_guard_mode_check = old_var.tag.nan_guard_mode_check class NoOutputFromInplace(Feature): diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 56c999f871..57800a3457 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -206,7 +206,11 @@ def add_client(self, var: Variable, new_client: ClientType) -> None: raise TypeError( 'The first entry of `new_client` must be an `Apply` node or the string `"output"`' ) - self.clients[var].append(new_client) + var_clients = self.clients[var] + # TODO: This might be another reason to use a type like + # `Dict[Variable, Set[Tuple[Apply, int]]]` for `FeatureGraph.clients` + if new_client not in var_clients: + var_clients.append(new_client) def remove_client( self, @@ -412,15 +416,15 @@ def change_node_input( reason: Optional[str] = None, import_missing: bool = False, check: bool = True, - ) -> None: - """Change ``node.inputs[i]`` to `new_var`. + ) -> Optional[Apply]: + """Create a clone of `node` in which ``node.inputs[i]`` is equal to `new_var`. ``new_var.type.is_super(old_var.type)`` must be ``True``, where ``old_var`` is the current value of ``node.inputs[i]`` which we want to replace. - For each feature that has an `on_change_input` method, this method calls: - ``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)`` + For each feature that has an `Feature.on_change_input` method, this method calls: + ``feature.on_change_input(function_graph, old_node, new_node, i, old_var, new_var, reason)`` Parameters ---------- @@ -440,35 +444,129 @@ def change_node_input( `History` `Feature`, which needs to revert types that have been narrowed and would otherwise fail this check. """ - # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) - if node == "output": - r = self.outputs[i] - if check and not r.type.is_super(new_var.type): + + is_output = node == "output" + + if is_output: + old_var = self.outputs[i] + + if old_var is new_var: + return None + + if check and not old_var.type.is_super(new_var.type): raise TypeError( f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." + f"compatible with the type of the original Variable ({old_var.type})." ) + self.outputs[i] = new_var + new_node: Optional[Apply] = new_var.owner + + self.import_var(new_var, reason=reason, import_missing=import_missing) + self.add_client(new_var, (node, i)) + self.remove_client(old_var, (node, i), reason=reason) + self.execute_callbacks( + "on_change_input", node, node, i, old_var, new_var, reason=reason + ) else: assert isinstance(node, Apply) - r = node.inputs[i] - if check and not r.type.is_super(new_var.type): + old_var = node.inputs[i] + + if old_var is new_var: + return None + + if check and not old_var.type.is_super(new_var.type): raise TypeError( f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." + f"compatible with the type of the original Variable ({old_var.type})." ) - node.inputs[i] = new_var - if r is new_var: - return + self.import_var(new_var, reason=reason, import_missing=import_missing) + + # In this case, we need to construct a new `Apply` node with + # `node.inputs[i] = new_var` + new_inputs = list(node.inputs) + new_inputs[i] = new_var + + # By passing `node.outputs` we're assigning those variables + # to this new node (i.e. by resetting `Variable.owner`). + # TODO: Perhaps a `change_owner` callback would be suitable. + new_node = Apply(node.op, new_inputs, node.outputs) + + old_outputs = new_node.outputs + new_node.outputs = node.outputs + + # This is just a sanity check + assert all(o.owner is new_node for o in node.outputs) + + # Next, we need to swap the old `node` with `new_node` in + # `FunctionGraph.clients`, as well as remove any now unused + # nodes and variables induced by the replacement itself. + + if new_node in self.apply_nodes: + # In this case, `new_node` isn't actually new to the graph, so + # all the entries connecting `new_node.inputs` to `new_node` + # are already present in `FunctionGraph.clients`. All we need + # to do is replace references to `new_node.outputs` (i.e. the + # pre-existing node) with `node.outputs`. + for old_out in old_outputs: + for o_node, o_i in self.clients[old_out]: + self.apply_nodes.remove( + o_node if o_node != "output" else self.outputs[o_i].owner + ) + + del self.clients[old_out] + self.variables.remove(old_out) + + else: + self.apply_nodes.add(new_node) + # self._import_node(new_node, reason=reason) + + self.add_client(new_var, (new_node, i)) + + # We need to replace all client references to the old node with the + # new node + for j, inp in enumerate(node.inputs): + if j != i: + self.add_client(inp, (new_node, j)) + # The old variable and node needs to be removed + self.remove_client( + inp, (node, j), reason=reason, remove_if_empty=True + ) + + # TODO: If we know that no intermediate nodes need to be + # removed, then we could perform the node replacements much + # more efficiently + # old_clients = self.clients[inp] + # # TODO: Were the clients list a `dict` mapping nodes to input + # # positions, we could simplify this considerably. + # for k, (client_, input_id) in enumerate(old_clients): + # # client = self.outputs[input_id] if client_ == "output" else client_ + # if client_ == node: + # old_clients[k] = (new_node, input_id) + + self.apply_nodes.remove(node) + + if not hasattr(node.tag, "removed_by"): + node.tag.removed_by = [] + + node.tag.removed_by.append(str(reason)) + + # This is here to simulate the old behavior + self.execute_callbacks("on_prune", node, reason) + self.execute_callbacks("on_import", new_node, reason) + + self.execute_callbacks( + "on_change_input", + node, + new_node, + i, + old_var, + new_var, + reason=reason, + ) - self.import_var(new_var, reason=reason, import_missing=import_missing) - self.add_client(new_var, (node, i)) - self.remove_client(r, (node, i), reason=reason) - # Precondition: the substitution is semantically valid However it may - # introduce cycles to the graph, in which case the transaction will be - # reverted later. - self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason) + return new_node def replace( self, @@ -532,10 +630,16 @@ def replace( f"test value. Original: {tval_shape}, new: {new_tval_shape}" ) + new_nodes: Dict[ApplyOrOutput, ApplyOrOutput] = {} for node, i in list(self.clients[var]): - self.change_node_input( - node, i, new_var, reason=reason, import_missing=import_missing + new_node = self.change_node_input( + new_nodes.get(node, node), + i, + new_var, + reason=reason, + import_missing=import_missing, ) + new_nodes[node] = new_node or node def replace_all(self, pairs: Iterable[Tuple[Variable, Variable]], **kwargs) -> None: """Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list.""" diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py index 135581820b..ac421e68d4 100644 --- a/aesara/graph/rewriting/basic.py +++ b/aesara/graph/rewriting/basic.py @@ -33,7 +33,7 @@ from aesara.graph.features import AlreadyThere, Feature, NodeFinder from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.utils import AssocList, InconsistencyError +from aesara.graph.utils import InconsistencyError from aesara.misc.ordered_set import OrderedSet from aesara.utils import flatten @@ -531,8 +531,7 @@ def on_attach(self, fgraph): fgraph.merge_feature = self self.seen_atomics = set() - self.atomic_sig = AssocList() - self.atomic_sig_inv = AssocList() + self.canonical_atomics = {} # For all Apply nodes # Set of distinct (not mergeable) nodes @@ -562,15 +561,15 @@ def on_attach(self, fgraph): def clone(self): return type(self)() - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if node in self.nodes_seen: + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): + if old_node in self.nodes_seen: # If inputs to a node change, it's not guaranteed that the node is # distinct from the other nodes in `self.nodes_seen`. - self.nodes_seen.discard(node) - self.process_node(fgraph, node) + self.nodes_seen.discard(old_node) + self.process_node(fgraph, new_node) - if isinstance(new_r, AtomicVariable): - self.process_atomic(fgraph, new_r) + if isinstance(new_var, AtomicVariable): + self.process_atomic(fgraph, new_var) def on_import(self, fgraph, node, reason): for c in node.inputs: @@ -586,17 +585,14 @@ def on_prune(self, fgraph, node, reason): for c in node.inputs: if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1: # This was the last node using this constant - sig = self.atomic_sig[c] - self.atomic_sig.discard(c) - self.atomic_sig_inv.discard(sig) + self.canonical_atomics.pop(c) self.seen_atomics.discard(id(c)) def process_atomic(self, fgraph, c): """Check if an atomic `c` can be merged, and queue that replacement.""" if id(c) in self.seen_atomics: return - sig = c.merge_signature() - other_c = self.atomic_sig_inv.get(sig, None) + other_c = self.canonical_atomics.get(c, None) if other_c is not None: # multiple names will clobber each other.. # we adopt convention to keep the last name @@ -605,8 +601,7 @@ def process_atomic(self, fgraph, c): self.scheduled.append([[(c, other_c, "merge")]]) else: # this is a new constant - self.atomic_sig[c] = sig - self.atomic_sig_inv[sig] = c + self.canonical_atomics[c] = c self.seen_atomics.add(id(c)) def process_node(self, fgraph, node): @@ -1662,9 +1657,9 @@ def on_prune(self, fgraph, node, reason): if self.pruner: self.pruner(node) - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): if self.chin: - self.chin(node, i, r, new_r, reason) + self.chin(old_node, new_node, i, old_var, new_var, reason) def on_detach(self, fgraph): # To allow pickling this object @@ -1798,7 +1793,7 @@ def attach_updater( if self.ignore_newtrees: importer = None - if importer is None and pruner is None: + if importer is None and pruner is None and chin is None: return None u = DispatchingFeature(importer, pruner, chin, name=name) @@ -1909,7 +1904,7 @@ def process_node( return False try: fgraph.replace_all_validate_remove( # type: ignore - repl_pairs, reason=node_rewriter, remove=remove + repl_pairs, reason=str(node_rewriter), remove=remove ) return True except Exception as e: @@ -1966,8 +1961,11 @@ def importer(node): if node is not current_node: q.append(node) + def change_input(old_node, new_node, i, old_var, new_var, reason): + q.append(new_node) + u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) + fgraph, importer, None, chin=change_input, name=getattr(self, "name", None) ) nb = 0 try: @@ -2108,12 +2106,19 @@ def apply(self, fgraph): q = list(fgraph.get_nodes(op)) def importer(node): - if node is not current_node: - if node.op == op: - q.append(node) + if node is not current_node and node.op == op: + q.append(node) + + def change_input(old_node, new_node, i, r, new_r, reason): + if ( + node is not current_node + and isinstance(new_node, Apply) + and new_node.op == op + ): + q.append(new_node) u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) + fgraph, importer, None, chin=change_input, name=getattr(self, "name", None) ) try: while q: @@ -2142,7 +2147,7 @@ def on_import(self, fgraph, node, reason): self.nb_imported += 1 self.changed = True - def on_change_input(self, fgraph, node, i, r, new_r, reason): + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): self.changed = True def reset(self): @@ -2357,15 +2362,24 @@ def importer(node): if node is not current_node: q.append(node) - chin = None if self.tracks_on_change_inputs: - def chin(node, i, r, new_r, reason): - if node is not current_node and not isinstance(node, str): - q.append(node) + def change_input(old_node, new_node, i, r, new_r, reason): + if old_node is not current_node and isinstance(new_node, Apply): + q.append(new_node) + + else: + + def change_input(old_node, new_node, i, r, new_r, reason): + if isinstance(new_node, Apply): + q.append(new_node) u = self.attach_updater( - fgraph, importer, None, chin=chin, name=getattr(self, "name", None) + fgraph, + importer, + None, + chin=change_input, + name=getattr(self, "name", None), ) try: while q: diff --git a/aesara/link/c/basic.py b/aesara/link/c/basic.py index 8aed25cd13..f601404086 100644 --- a/aesara/link/c/basic.py +++ b/aesara/link/c/basic.py @@ -1416,15 +1416,13 @@ def in_sig(i, topological_pos, i_idx): # yield a 'position' that reflects its role in code_gen() if isinstance(i, AtomicVariable): # orphans if id(i) not in constant_ids: - isig = (i.signature(), topological_pos, i_idx) + isig = (hash(i), topological_pos, i_idx) # If the Aesara constant provides a strong hash # (no collision for transpose, 2, 1, 0, -1, -2, # 2 element swapped...) we put this hash in the signature # instead of the value. This makes the key file much # smaller for big constant arrays. Before this, we saw key # files up to 80M. - if hasattr(isig[0], "aesara_hash"): - isig = (isig[0].aesara_hash(), topological_pos, i_idx) try: hash(isig) except Exception: diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index 928b92ed2c..6e4007ef1c 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -291,9 +291,11 @@ def __hash__(self): # NB: For writing, we must bypass setattr() which is always called by default by Python. self.__dict__["__signatures__"] = tuple( # NB: Params object should have been already filtered. - self.__params_type__.types[i] - .make_constant(self[self.__params_type__.fields[i]]) - .signature() + hash( + self.__params_type__.types[i].make_constant( + self[self.__params_type__.fields[i]] + ) + ) for i in range(self.__params_type__.length) ) return hash((type(self), self.__params_type__) + self.__signatures__) diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 33632fa1a6..e3dc2f4c9e 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -292,15 +292,7 @@ def __setstate__(self, dct): class CDataTypeConstant(Constant[T]): - def merge_signature(self): - # We don't want to merge constants that don't point to the - # same object. - return id(self.data) - - def signature(self): - # There is no way to put the data in the signature, so we - # don't even try - return (self.type,) + pass CDataType.constant_type = CDataTypeConstant diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index 46ac71d8ce..46bda1dc93 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -24,7 +24,6 @@ from aesara.link.c.type import generic from aesara.misc.safe_asarray import _asarray from aesara.sparse.type import SparseTensorType, _is_sparse -from aesara.sparse.utils import hash_from_sparse from aesara.tensor import basic as at from aesara.tensor.basic import Split from aesara.tensor.math import _conj @@ -441,35 +440,33 @@ def __repr__(self): return str(self) -class SparseConstantSignature(tuple): +class SparseConstant(TensorConstant, _sparse_py_operators): + format = property(lambda self: self.type.format) + + # def __init__(self, *args): + # .view(HashableNDArray) + def __eq__(self, other): - (a, b), (x, y) = self, other - return ( - a == x - and (b.dtype == y.dtype) - and (type(b) == type(y)) - and (b.shape == y.shape) - and (abs(b - y).sum() < 1e-6 * b.nnz) - ) + if isinstance(other, type(self)): + b = self.data + y = other.data + if ( + self.type == other.type + and (b.dtype == y.dtype) + and (type(b) == type(y)) + and (b.shape == y.shape) + and (abs(b - y).sum() < 1e-6 * b.nnz) + ): + return True + return False + + return NotImplemented def __ne__(self, other): return not self == other def __hash__(self): - (a, b) = self - return hash(type(self)) ^ hash(a) ^ hash(type(b)) - - def aesara_hash(self): - (_, d) = self - return hash_from_sparse(d) - - -class SparseConstant(TensorConstant, _sparse_py_operators): - format = property(lambda self: self.type.format) - - def signature(self): - assert self.data is not None - return SparseConstantSignature((self.type, self.data)) + return hash((type(self), self.type, self.data)) def __str__(self): return "{}{{{},{},shape={},nnz={}}}".format( diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 4762d903d2..7009d60a01 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -228,6 +228,8 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: ttype = TensorType(dtype=x_.dtype, shape=x_.shape) + x_.setflags(write=0) + return TensorConstant(ttype, x_, name=name) diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py index 87d77b1322..da0ea410d4 100644 --- a/aesara/tensor/rewriting/shape.py +++ b/aesara/tensor/rewriting/shape.py @@ -1,6 +1,6 @@ import traceback from io import StringIO -from typing import Optional +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from typing import cast as type_cast from warnings import warn @@ -51,6 +51,10 @@ from aesara.tensor.type_other import NoneConst +if TYPE_CHECKING: + from aesara.graph.basic import Apply + + class ShapeFeature(Feature): r"""A `Feature` that tracks shape information in a graph. @@ -366,8 +370,8 @@ def set_shape(self, r, s, override=False): assert all( not hasattr(r.type, "shape") or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) + or self.lscalar_one == shape_vars[i] + or self.lscalar_one == extract_constant(shape_vars[i]) for i in range(r.type.ndim) ) self.shape_of[r] = tuple(shape_vars) @@ -508,13 +512,13 @@ def on_attach(self, fgraph): self.lscalar_one = constant(1, dtype="int64") assert self.lscalar_one.type.dtype == "int64" - self.fgraph = fgraph + self.fgraph: FunctionGraph = fgraph # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} + self.shape_of: Dict[Variable, Optional[Tuple[Variable]]] = {} # Variable -> - self.scheduled = {} + self.scheduled: Dict["Apply", Variable] = {} # shape var -> graph v - self.shape_of_reverse_index = {} + self.shape_of_reverse_index: Dict[Variable, Set[Variable]] = {} for node in fgraph.toposort(): self.on_import(fgraph, node, reason="on_attach") @@ -586,34 +590,37 @@ def on_import(self, fgraph, node, reason): for r, s in zip(node.outputs, o_shapes): self.set_shape(r, s) - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) + def on_change_input(self, fgraph, old_node, new_node, i, old_var, new_var, reason): + if new_var not in self.shape_of: + # It happen that the fgraph didn't call `ShapeFeature.on_import` for some + # `new_var`. This can happen when `new_var` doesn't have an + # owner (i.e. it is a constant or an input of the graph). + # FYI: `ShapeFeature.update_shape` suppose that `old_var` and `new_var` are in shape_of. + self.init_r(new_var) - # This tells us that r and new_r must have the same shape if + # This tells us that `old_var` and `new_var` must have the same shape if # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) + self.update_shape(new_var, old_var) - # change_input happens in two cases: - # 1) we are trying to get rid of r, or + # Let's consider two (mutually exclusive?) cases: + # 1) we are trying to get rid of `old_var`, or # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that - # r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - for (shpnode, idx) in fgraph.clients[r] + [(node, i)]: + # + # In case 1, if `old_var` has a `ShapeFeature.shape_i` client, we will want to + # replace the shape_i of `old_var` with the shape of `new_var` (i.e. we say that + # `old_var` is *scheduled*). + # + # At that point, `old_node` is no longer a client of `old_var`, and all the clients + # of `old_node` now belong to `new_node`. + + for (shpnode, idx) in fgraph.clients.get(old_var, []) + [(new_node, i)]: if isinstance(getattr(shpnode, "op", None), Shape_i): idx = shpnode.op.i - repl = self.shape_of[new_r][idx] + repl = self.shape_of[new_var][idx] if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. + # This means the replacement shape object is exactly the + # same as the current shape object, so no need for + # replacement. continue if ( repl.owner @@ -629,30 +636,31 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): if shpnode.outputs[0] in ancestors([repl]): raise InconsistencyError( "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" + f"old_node: {old_node}, new_node: {new_node}, i: {i}, " + f"old_var: {old_var}, new_var: {new_var}" ) - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, + self.scheduled[shpnode] = new_var + # In case 2, if `old_var` is a variable that we've scheduled for shape update, # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] + unscheduled = [k for k, v in self.scheduled.items() if v == old_var] for k in unscheduled: del self.scheduled[k] - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): + # In either case, `old_var` could be in shape_of.values(), that is, + # `old_var` itself is the shape of something. In that case, we want to + # update the value in shape_of, to keep it up-to-date. + for v in self.shape_of_reverse_index.get(old_var, ()): # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + # deletion of variables or `Feature.on_change_input`, so it might + # be the case that there are a few extra `v`'s in it that no longer + # have a shape of `old_var` or possibly have been deleted from + # `ShapeFeature.shape_of` entirely. The important thing is that it + # permits to recall all variables with `old_var` in their shape. + for ii, svi in enumerate(self.shape_of.get(v, ())): + if svi == old_var: + self.set_shape_i(v, ii, new_var) + self.shape_of_reverse_index[old_var] = set() def same_shape( self, @@ -684,10 +692,10 @@ def same_shape( return False if dim_x is not None: - sx = [sx[dim_x]] + sx = (sx[dim_x],) if dim_y is not None: - sy = [sy[dim_y]] + sy = (sy[dim_y],) if len(sx) != len(sy): return False @@ -710,11 +718,7 @@ def same_shape( rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), ) canon_shapes = canon_shapes_fg.outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - - for dx, dy in zip(sx, sy): + for dx, dy in zip(canon_shapes[: len(sx)], canon_shapes[len(sx) :]): if not equal_computations([dx], [dy]): return False diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 5890b6e22e..74d7ae02a7 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -12,7 +12,7 @@ from aesara.graph.utils import MetaType from aesara.link.c.type import CType from aesara.misc.safe_asarray import _asarray -from aesara.utils import apply_across_args +from aesara.utils import HashableNDArray, apply_across_args if TYPE_CHECKING: @@ -64,7 +64,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): filter_checks_isfinite = False """ When this is ``True``, strict filtering rejects data containing - ``numpy.nan`` or ``numpy.inf`` entries. (Used in `DebugMode`) + `numpy.nan` or `numpy.inf` entries. (Used in `DebugMode`) """ def __init__( @@ -253,6 +253,13 @@ def filter(self, data, strict=False, allow_downcast=None): if self.filter_checks_isfinite and not np.all(np.isfinite(data)): raise ValueError("Non-finite elements not allowed") + + if not isinstance(data, HashableNDArray): + return data.view(HashableNDArray) + + # Make sure it's read-only so that we can cache hash values and such + data.setflags(write=0) + return data def filter_variable(self, other, allow_convert=True): diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index e0c438c5e5..00c7ed3048 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -57,12 +57,25 @@ def clone(self, **kwargs): def filter(self, x, strict=False, allow_downcast=None): if isinstance(x, slice): + + if isinstance(x.start, np.ndarray): + assert str(x.start.dtype) in integer_dtypes + x = slice(x.start.item(), x.stop, x.step) + + if isinstance(x.stop, np.ndarray): + assert str(x.stop.dtype) in integer_dtypes + x = slice(x.start, x.stop.item(), x.step) + + if isinstance(x.step, np.ndarray): + assert str(x.step.dtype) in integer_dtypes + x = slice(x.start, x.stop, x.step.item()) + return x else: raise TypeError("Expected a slice!") def __str__(self): - return "slice" + return f"{type(self)}()" def __eq__(self, other): return type(self) == type(other) @@ -80,25 +93,23 @@ def may_share_memory(a, b): class SliceConstant(Constant): + @classmethod + def create_key(cls, type, data, *args, **kwargs): + return (type, data.start, data.stop, data.step) + def __init__(self, type, data, name=None): - assert isinstance(data, slice) - # Numpy ndarray aren't hashable, so get rid of them. - if isinstance(data.start, np.ndarray): - assert data.start.ndim == 0 - assert str(data.start.dtype) in integer_dtypes - data = slice(int(data.start), data.stop, data.step) - elif isinstance(data.stop, np.ndarray): - assert data.stop.ndim == 0 - assert str(data.stop.dtype) in integer_dtypes - data = slice(data.start, int(data.stop), data.step) - elif isinstance(data.step, np.ndarray): - assert data.step.ndim == 0 - assert str(data.step.dtype) in integer_dtypes - data = slice(data.start, int(data.stop), data.step) - Constant.__init__(self, type, data, name) - - def signature(self): - return (SliceConstant, self.data.start, self.data.stop, self.data.step) + super().__init__(type, data, name) + + def __eq__(self, other): + if isinstance(other, type(self)): + if self.data == other.data: + return True + return False + + return NotImplemented + + def __hash__(self): + return hash(self.data.__reduce__()) def __str__(self): return "{}{{{}, {}, {}}}".format( diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index 8b281e6bd0..3a7d3bb02d 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -1,4 +1,3 @@ -import copy import traceback as tb import warnings from collections.abc import Iterable @@ -16,7 +15,6 @@ from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.type import TensorType from aesara.tensor.type_other import NoneConst -from aesara.tensor.utils import hash_from_ndarray _TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType) @@ -877,119 +875,6 @@ def _get_vector_length_TensorVariable(op_or_var, var): TensorType.variable_type = TensorVariable -class TensorConstantSignature(tuple): - r"""A signature object for comparing `TensorConstant` instances. - - An instance is a pair with the type ``(Type, ndarray)``. - - TODO FIXME: Subclassing `tuple` is unnecessary, and it appears to be - preventing the use of a much more convenient `__init__` that removes the - need for all these lazy computations and their safety checks. - - Also, why do we even need this signature stuff? We could simply implement - good `Constant.__eq__` and `Constant.__hash__` implementations. - - We could also produce plain `tuple`\s with hashable values. - - """ - - def __eq__(self, other): - if type(self) != type(other): - return False - try: - (t0, d0), (t1, d1) = self, other - except Exception: - return False - - # N.B. compare shape to ensure no broadcasting in == - if t0 != t1 or d0.shape != d1.shape: - return False - - self.no_nan # Ensure has_nan is computed. - # Note that in the comparisons below, the elementwise comparisons - # come last because they are the most expensive checks. - if self.has_nan: - other.no_nan # Ensure has_nan is computed. - return ( - other.has_nan - and self.sum == other.sum - and (self.no_nan.mask == other.no_nan.mask).all() - and - # Note that the second test below (==) may crash e.g. for - # a single scalar NaN value, so we do not run it when all - # values are missing. - (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) - ) - else: - # Simple case where we do not need to worry about NaN values. - # (note that if there are NaN values in d1, this will return - # False, which is why we do not bother with testing `other.has_nan` - # here). - return (self.sum == other.sum) and np.all(d0 == d1) - - def __ne__(self, other): - return not self == other - - def __hash__(self): - t, d = self - return hash((type(self), t, d.shape, self.sum)) - - def aesara_hash(self): - _, d = self - return hash_from_ndarray(d) - - @property - def sum(self): - """Compute sum of non NaN / Inf values in the array.""" - try: - return self._sum - except AttributeError: - - # Prevent warnings when there are `inf`s and `-inf`s present - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - self._sum = self.no_nan.sum() - - # The following 2 lines are needed as in Python 3.3 with NumPy - # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. - if isinstance(self._sum, np.memmap): - self._sum = np.asarray(self._sum).item() - - if self.has_nan and self.no_nan.mask.all(): - # In this case the sum is not properly computed by numpy. - self._sum = 0 - - if np.isinf(self._sum) or np.isnan(self._sum): - # NaN may happen when there are both -inf and +inf values. - if self.has_nan: - # Filter both NaN and Inf values. - mask = self.no_nan.mask + np.isinf(self[1]) - else: - # Filter only Inf values. - mask = np.isinf(self[1]) - if mask.all(): - self._sum = 0 - else: - self._sum = np.ma.masked_array(self[1], mask).sum() - # At this point there should be no more NaN. - assert not np.isnan(self._sum) - - if isinstance(self._sum, np.ma.core.MaskedConstant): - self._sum = 0 - - return self._sum - - @property - def no_nan(self): - try: - return self._no_nan - except AttributeError: - nans = np.isnan(self[1]) - self._no_nan = np.ma.masked_array(self[1], nans) - self.has_nan = np.any(nans) - return self._no_nan - - def get_unique_value(x: TensorVariable) -> Optional[Number]: """Return the unique value of a tensor, if there is one""" if isinstance(x, Constant): @@ -998,7 +883,7 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]: if isinstance(data, np.ndarray) and data.ndim > 0: flat_data = data.ravel() if flat_data.shape[0]: - if (flat_data == flat_data[0]).all(): + if np.all(flat_data == flat_data[0]): return flat_data[0] return None @@ -1022,6 +907,8 @@ def __init__(self, type: _TensorTypeType, data, name=None): assert not any(s is None for s in new_type.shape) + data.setflags(write=0) + Constant.__init__(self, new_type, data, name) def __str__(self): @@ -1039,31 +926,6 @@ def __str__(self): name = "TensorConstant" return "%s{%s}" % (name, val) - def signature(self): - return TensorConstantSignature((self.type, self.data)) - - def equals(self, other): - # Override Constant.equals to allow to compare with - # numpy.ndarray, and python type. - if isinstance(other, (np.ndarray, int, float)): - # Make a TensorConstant to be able to compare - other = at.basic.constant(other) - return ( - isinstance(other, TensorConstant) and self.signature() == other.signature() - ) - - def __copy__(self): - # We need to do this to remove the cached attribute - return type(self)(self.type, self.data, self.name) - - def __deepcopy__(self, memo): - # We need to do this to remove the cached attribute - return type(self)( - copy.deepcopy(self.type, memo), - copy.deepcopy(self.data, memo), - copy.deepcopy(self.name, memo), - ) - TensorType.constant_type = TensorConstant diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index d0d85030a6..0350ed28bd 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -42,7 +42,9 @@ class AssertNoChanges(Feature): """A `Feature` that raises an error when nodes are changed in a graph.""" - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + def on_change_input( + self, fgraph, old_node, new_node, i, old_var, new_var, reason=None + ): raise AssertionError() @@ -591,14 +593,9 @@ def local_rewrite_2(fgraph, node): capres = capsys.readouterr() assert capres.err == "" - assert ( - "rewriting: rewrite local_rewrite_1 replaces node Op1(x, y) with [Op2.0]" - in capres.out - ) - assert ( - "rewriting: rewrite local_rewrite_2 replaces node Op2(y, y) with [Op2.0]" - in capres.out - ) + out1, out2 = capres.out.split("\n", maxsplit=1) + assert out1.startswith("rewriting: rewrite local_rewrite_1 replaces") + assert out2.startswith("rewriting: rewrite local_rewrite_2 replaces") def test_node_rewriter_str(): diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index cdd362b00b..aae868b259 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -8,6 +8,7 @@ from aesara import tensor as at from aesara.graph.basic import ( Apply, + Constant, NominalVariable, Variable, ancestors, @@ -41,10 +42,14 @@ ) from aesara.tensor.type_other import NoneConst from aesara.tensor.var import TensorVariable +from aesara.utils import HashableNDArray from tests import unittest_tools as utt from tests.graph.utils import MyInnerGraphOp +pytestmark = pytest.mark.filterwarnings("error") + + class MyType(Type): def __init__(self, thingy): self.thingy = thingy @@ -84,7 +89,7 @@ def perform(self, *args, **kwargs): raise NotImplementedError("No Python implementation available.") -MyOp = MyOp() +my_op = MyOp() def leaf_formatter(leaf): @@ -107,29 +112,29 @@ def format_graph(inputs, outputs): class TestStr: def test_as_string(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) + node = my_op.make_node(r1, r2) s = format_graph([r1, r2], node.outputs) assert s == ["MyOp(R1, R2)"] def test_as_string_deep(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], r5) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], r5) s = format_graph([r1, r2, r5], node2.outputs) assert s == ["MyOp(MyOp(R1, R2), R5)"] def test_multiple_references(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], node.outputs[0]) assert format_graph([r1, r2, r5], node2.outputs) == [ "MyOp(*1 -> MyOp(R1, R2), *1)" ] def test_cutoff(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], node.outputs[0]) assert format_graph(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] assert format_graph(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] @@ -137,43 +142,27 @@ def test_cutoff(self): class TestClone: def test_accurate(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) - _, new = clone([r1, r2], node.outputs, False) + node = my_op.make_node(r1, r2) + _, new = clone([r1, r2], node.outputs, copy_inputs=False) assert format_graph([r1, r2], new) == ["MyOp(R1, R2)"] def test_copy(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], r5) - _, new = clone([r1, r2, r5], node2.outputs, False) - assert ( - node2.outputs[0].type == new[0].type and node2.outputs[0] is not new[0] - ) # the new output is like the old one but not the same object - assert node2 is not new[0].owner # the new output has a new owner + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], r5) + _, new = clone([r1, r2, r5], node2.outputs, copy_inputs=False) + assert node2.outputs[0].type == new[0].type and node2.outputs[0] is new[0] + assert node2 is new[0].owner assert new[0].owner.inputs[1] is r5 # the inputs are not copied assert ( new[0].owner.inputs[0].type == node.outputs[0].type - and new[0].owner.inputs[0] is not node.outputs[0] - ) # check that we copied deeper too - - def test_not_destructive(self): - # Checks that manipulating a cloned graph leaves the original unchanged. - r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) - _, new = clone([r1, r2, r5], node.outputs, False) - new_node = new[0].owner - new_node.inputs = [MyVariable(7), MyVariable(8)] - assert format_graph(graph_inputs(new_node.outputs), new_node.outputs) == [ - "MyOp(R7, R8)" - ] - assert format_graph(graph_inputs(node.outputs), node.outputs) == [ - "MyOp(MyOp(R1, R2), R5)" - ] + and new[0].owner.inputs[0] is node.outputs[0] + ) def test_constant(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) - _, new = clone([r1, r2, r5], node.outputs, False) + node = my_op.make_node(my_op.make_node(r1, r2).outputs[0], r5) + _, new = clone([r1, r2, r5], node.outputs, copy_inputs=False) new_node = new[0].owner new_node.inputs = [MyVariable(7), MyVariable(8)] c1 = at.constant(1.5) @@ -192,13 +181,13 @@ def test_constant(self): def test_clone_inner_graph(self): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" # Inner graph igo_in_1 = MyVariable(4) igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1 = my_op(igo_in_1, igo_in_2) igo_out_1.name = "igo1" igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) @@ -209,8 +198,8 @@ def test_clone_inner_graph(self): o2_node = o2.owner o2_node_clone = o2_node.clone(clone_inner_graph=True) - assert o2_node_clone is not o2_node - assert o2_node_clone.op.fgraph is not o2_node.op.fgraph + assert o2_node_clone is o2_node + assert o2_node_clone.op.fgraph is o2_node.op.fgraph assert equal_computations( o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs ) @@ -228,9 +217,9 @@ class TestToposort: def test_simple(self): # Test a simple graph r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - o = MyOp(r1, r2) + o = my_op(r1, r2) o.name = "o1" - o2 = MyOp(o, r5) + o2 = my_op(o, r5) o2.name = "o2" clients = {} @@ -257,49 +246,50 @@ def test_simple(self): def test_double_dependencies(self): # Test a graph with double dependencies r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) - o2 = MyOp.make_node(o.outputs[0], r5) + o = my_op.make_node(r1, r1) + o2 = my_op.make_node(o.outputs[0], r5) all = general_toposort(o2.outputs, prenode) assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]] def test_inputs_owners(self): # Test a graph where the inputs have owners r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) + o = my_op.make_node(r1, r1) r2b = o.outputs[0] - o2 = MyOp.make_node(r2b, r2b) + o2 = my_op.make_node(r2b, r2b) all = io_toposort([r2b], o2.outputs) assert all == [o2] - o2 = MyOp.make_node(r2b, r5) + o2 = my_op.make_node(r2b, r5) all = io_toposort([r2b], o2.outputs) assert all == [o2] def test_not_connected(self): # Test a graph which is not connected r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(r3, r4) + o0 = my_op.make_node(r1, r2) + o1 = my_op.make_node(r3, r4) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) assert all == [o1, o0] or all == [o0, o1] def test_io_chain(self): # Test inputs and outputs mixed together in a chain graph r1, r2 = MyVariable(1), MyVariable(2) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(o0.outputs[0], r1) + o0 = my_op.make_node(r1, r2) + o1 = my_op.make_node(o0.outputs[0], r1) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) assert all == [o1] def test_outputs_clients(self): # Test when outputs have clients r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - MyOp.make_node(o0.outputs[0], r4) + o0 = my_op.make_node(r1, r2) + my_op.make_node(o0.outputs[0], r4) all = io_toposort([], o0.outputs) assert all == [o0] +@pytest.mark.skip(reason="Not finished") class TestEval: def setup_method(self): self.x, self.y = scalars("x", "y") @@ -397,9 +387,9 @@ def test_equal_computations(): def test_walk(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" def expand(r): @@ -428,9 +418,9 @@ def expand(r): def test_ancestors(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = ancestors([o2], blockers=None) @@ -450,9 +440,9 @@ def test_ancestors(): def test_graph_inputs(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = graph_inputs([o2], blockers=None) @@ -463,9 +453,9 @@ def test_graph_inputs(): def test_variables_and_orphans(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" vars_res = vars_between([r1, r2], [o2]) @@ -480,11 +470,11 @@ def test_variables_and_orphans(): def test_ops(): r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, r4) + o2 = my_op(r3, r4) o2.name = "o2" - o3 = MyOp(r3, o1, o2) + o3 = my_op(r3, o1, o2) o3.name = "o3" res = applys_between([r1, r2], [o3]) @@ -495,9 +485,9 @@ def test_ops(): def test_list_of_nodes(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = list_of_nodes([r1, r2], [o2]) @@ -507,9 +497,9 @@ def test_list_of_nodes(): def test_is_in_ancestors(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" assert is_in_ancestors(o2.owner, o1.owner) @@ -528,13 +518,13 @@ def test_view_roots(): def test_get_var_by_name(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" # Inner graph igo_in_1 = MyVariable(4) igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1 = my_op(igo_in_1, igo_in_2) igo_out_1.name = "igo1" igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) @@ -667,6 +657,7 @@ def test_cloning_replace_not_strict_not_copy_inputs(self): assert x not in f2_inp assert y2 not in f2_inp + @pytest.mark.skip(reason="Not finished") def test_clone(self): def test(x, y, mention_y): if mention_y: @@ -771,7 +762,7 @@ def test_NominalVariable(): assert repr(nv5) == f"NominalVariable(2, {repr(type3)})" - assert nv5.signature() == (type3, 2) + assert hash(nv5) == hash((type(nv5), 2, type3)) nv5_pkld = pickle.dumps(nv5) nv5_unpkld = pickle.loads(nv5_pkld) @@ -807,5 +798,81 @@ def test_NominalVariable_create_variable_type(): ntv_unpkld = pickle.loads(ntv_pkld) assert type(ntv_unpkld) is type(ntv) - assert ntv_unpkld.equals(ntv) + assert ntv_unpkld == ntv assert ntv_unpkld is ntv + + +def test_Apply_equivalence(): + + type1 = MyType(1) + + in_1 = Variable(type1, None, name="in_1") + in_2 = Variable(type1, None, name="in_2") + out_10 = Variable(type1, None, name="out_10") + out_11 = Variable(type1, None, name="out_11") + out_12 = Variable(type1, None, name="out_12") + + apply_1 = Apply(my_op, [in_1], [out_10]) + apply_2 = Apply(my_op, [in_1], [out_11]) + apply_3 = Apply(my_op, [in_2], [out_12]) + + assert apply_1 is apply_2 + assert apply_1 == apply_2 + assert apply_1 != apply_3 + assert hash(apply_1) == hash(apply_2) + assert hash(apply_1) != hash(apply_3) + + assert apply_1.inputs == apply_2.inputs + + assert apply_1.outputs == [out_10] + assert apply_2.outputs == [out_10] + # Output `Variable`s should be updated when the constructor is called with + # the same inputs but different outputs. + assert out_10.owner is apply_1 + assert out_11.owner is apply_1 + + apply_1_pkl = pickle.dumps(apply_1) + apply_1_2 = pickle.loads(apply_1_pkl) + + assert apply_1.op == apply_1_2.op + assert len(apply_1.inputs) == len(apply_1_2.inputs) + assert len(apply_1.outputs) == len(apply_1_2.outputs) + assert apply_1.inputs[0].type == apply_1_2.inputs[0].type + assert apply_1.inputs[0].name == apply_1_2.inputs[0].name + assert apply_1.outputs[0].type == apply_1_2.outputs[0].type + assert apply_1.outputs[0].name == apply_1_2.outputs[0].name + + +class MyType2(MyType): + def filter(self, value, **kwargs): + value = np.asarray(value).view(HashableNDArray) + value.setflags(write=0) + return value + + +def test_Constant_equivalence(): + type1 = MyType2(1) + x = Constant(type1, 1.0) + y = Constant(type1, 1.0) + + assert x == y + assert x is y + + rng = np.random.default_rng(3209) + a_val = rng.normal(size=(2, 3)) + c_val = rng.normal(size=(2, 3)) + + a = Constant(type1, a_val) + b = Constant(type1, a_val) + c = Constant(type1, c_val) + + assert a == b + assert a is b + assert a != x + assert a != c + + a_pkl = pickle.dumps(a) + a_2 = pickle.loads(a_pkl) + + assert a.type == a_2.type + assert a.data == a_2.data diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 3470284e66..62e2f22b9c 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -44,6 +44,9 @@ def filter(self, data): def __eq__(self, other): return isinstance(other, MyType) + def __hash__(self): + return id(self) + def MyVariable(name): return Variable(MyType(), None, None, name=name) @@ -156,12 +159,16 @@ def test_misc(): @assertFailure_fast def test_aliased_inputs_replacement(): - x, y, z = inputs() + x, *_ = inputs() tv = transpose_view(x) + tv.name = "tv" tvv = transpose_view(tv) + tvv.name = "tvv" sx = sigmoid(x) + sx.name = "sx" e = add_in_place(x, tv) - g = create_fgraph([x, y], [e], False) + e.name = "e" + g = create_fgraph([x], [e], False) assert not g.consistent() g.replace(tv, sx) assert g.consistent() @@ -310,16 +317,48 @@ def test_indirect_2(): @assertFailure_fast def test_long_destroyers_loop(): x, y, z = inputs() - e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) + add_xy = add_in_place(x, y) + add_xy.name = "add_i_xy" + add_yz = add_in_place(y, z) + add_yz.name = "add_i_yz" + add_zx = add(z, x) + add_zx.name = "add_zx" + dot_add_xy_yz = dot(add_xy, add_yz) + dot_add_xy_yz.name = "dot_add_xy_yz" + e = dot(dot_add_xy_yz, add_zx) + e.name = "e" g = create_fgraph([x, y, z], [e]) + + orderings = g.destroy_handler.orderings(g, ordered=False) + exp_orderings = {add_yz.owner: {add_xy}, add_xy.owner: {add_zx}} + assert orderings == exp_orderings + assert g.consistent() + + # This apparently introduces a cycle into the graph? + # That means it should fail validation and revert the replacement. + # TODO FIXME: We need tests that directly confirm the results of the + # functions in `DestroyHandler`, and not these extremely indirect + # integration-like tests that assert almost to nothing about the results + # produced by the code we're testing. + # TODO FIXME: Also, why are we even allowing `FunctionGraph`s to take + # these broken states? A quick cycle check in `FunctionGraph.replace` + # would be a lot better. TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g) + + # When `g` is in its inconsistent state the orderings are as follows: + # {AddInPlace(y, z): {AddInPlace(x, y)}, + # AddInPlace(x, y): {AddInPlace(z, x)}, + # AddInPlace(z, x): {AddInPlace(y, z)}} + + # Make sure the replacement was reverted + assert g.outputs[0].owner.inputs[-1].owner.op == add + + orderings = g.destroy_handler.orderings(g, ordered=False) + assert orderings == exp_orderings + assert g.consistent() - # we don't want to see that! - assert ( - str(g) - != "FunctionGraph(Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x)))" - ) + e2 = dot(dot(add_in_place(x, y), add_in_place(y, z)), add_in_place(z, x)) with pytest.raises(InconsistencyError): create_fgraph(*clone([x, y, z], [e2])) @@ -337,8 +376,8 @@ def test_misc_2(): def test_multi_destroyers(): x, y, z = inputs() - e = add(add_in_place(x, y), add_in_place(x, y)) - with pytest.raises(InconsistencyError): + e = add(add_in_place(x, y), add_in_place(x, z)) + with pytest.raises(InconsistencyError, match="Multiple destroyers of"): create_fgraph([x, y, z], [e]) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index c145a99e24..50a57bef63 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -6,7 +6,7 @@ from typing_extensions import Literal from aesara.configdefaults import config -from aesara.graph.basic import NominalVariable +from aesara.graph.basic import Apply, NominalVariable from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph from aesara.graph.utils import MissingInputError @@ -307,8 +307,11 @@ def test_change_input(self): var1 = MyVariable("var1") var2 = MyVariable("var2") var3 = op1(var2, var1) + var3.name = "var3" var4 = op2(var3, var2) + var4.name = "var4" var5 = op3(var4, var2, var2) + var5.name = "var5" cb_tracker = CallbackTracker() fg = FunctionGraph( [var1, var2], [var3, var5], clone=False, features=[cb_tracker] @@ -345,6 +348,7 @@ def test_change_input(self): old_apply_nodes = set(fg.apply_nodes) old_variables = set(fg.variables) old_var5_clients = list(fg.get_clients(var5)) + old_var5_node = var5.owner # We're replacing with the same variable, so nothing should happen fg.change_node_input(var5.owner, 1, var2) @@ -362,28 +366,40 @@ def test_change_input(self): assert fg.outputs[1].owner == var5.owner assert (var5.owner, 1) not in fg.get_clients(var2) - assert len(cb_tracker.callback_history) == 1 - assert cb_tracker.callback_history[0] == ( - "change_input", - (fg, var5.owner, 1, var2, var1), - {"reason": None}, - ) + assert len(cb_tracker.callback_history) == 3 + assert cb_tracker.callback_history == [ + ("prune", (fg, old_var5_node, None), {}), + ("import", (fg, var5.owner, None), {}), + ( + "change_input", + (fg, old_var5_node, var5.owner, 1, var2, var1), + {"reason": None}, + ), + ] cb_tracker.callback_history.clear() + old_var5_node = var5.owner + # Perform a valid `Apply` node input change that results in a # node removal (i.e. `var4.owner`) fg.change_node_input(var5.owner, 0, var1) assert var5.owner.inputs[0] is var1 - assert not fg.get_clients(var4) + assert var4 not in fg.clients assert var4.owner not in fg.apply_nodes assert var4 not in fg.variables - assert len(cb_tracker.callback_history) == 2 + assert len(cb_tracker.callback_history) == 4 assert cb_tracker.callback_history[0] == ("prune", (fg, var4.owner, None), {}) assert cb_tracker.callback_history[1] == ( + "prune", + (fg, old_var5_node, None), + {}, + ) + assert cb_tracker.callback_history[2] == ("import", (fg, var5.owner, None), {}) + assert cb_tracker.callback_history[3] == ( "change_input", - (fg, var5.owner, 0, var4, var1), + (fg, old_var5_node, var5.owner, 0, var4, var1), {"reason": None}, ) @@ -446,23 +462,32 @@ def test_replace(self): assert len(cb_tracker.callback_history) == 0 + old_var4_node = var4.owner + # Test a basic replacement fg.replace_all([(var3, var1)]) assert var3 not in fg.variables + assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var1, var2] assert fg.outputs == [var1, var5] - assert len(cb_tracker.callback_history) == 3 + assert len(cb_tracker.callback_history) == 5 assert cb_tracker.callback_history[0] == ( "change_input", - (fg, "output", 0, var3, var1), + (fg, "output", "output", 0, var3, var1), {"reason": None}, ) assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) assert cb_tracker.callback_history[2] == ( + "prune", + (fg, old_var4_node, None), + {}, + ) + assert cb_tracker.callback_history[3] == ("import", (fg, var4.owner, None), {}) + assert cb_tracker.callback_history[4] == ( "change_input", - (fg, var4.owner, 0, var3, var1), + (fg, old_var4_node, var4.owner, 0, var3, var1), {"reason": None}, ) @@ -472,6 +497,16 @@ def test_replace(self): cb_tracker = CallbackTracker() fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_var5_node = var5.owner + # Test a replacement that would remove the replacement variable # (i.e. `var3`) from the graph when the variable to be replaced # (i.e. `var4`) is removed @@ -483,12 +518,14 @@ def test_replace(self): assert fg.variables == {var1, var3, var5} assert cb_tracker.callback_history == [ - ("attach", (fg,), {}), - ("import", (fg, var3.owner, "init"), {}), - ("import", (fg, var4.owner, "init"), {}), - ("import", (fg, var5.owner, "init"), {}), ("prune", (fg, var4.owner, None), {}), - ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ("prune", (fg, old_var5_node, None), {}), + ("import", (fg, var5.owner, None), {}), + ( + "change_input", + (fg, old_var5_node, var5.owner, 0, var4, var3), + {"reason": None}, + ), ] var3 = op1(var1) @@ -497,6 +534,16 @@ def test_replace(self): cb_tracker = CallbackTracker() fg = FunctionGraph([var1], [var5], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, var3.owner, "init"), {}), + ("import", (fg, var4.owner, "init"), {}), + ("import", (fg, var5.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_var5_node = var5.owner + # Test multiple `change_node_input` calls on the same node fg.replace_all([(var4, var3)]) @@ -505,14 +552,24 @@ def test_replace(self): assert fg.outputs == [var5] assert fg.variables == {var1, var3, var5} + tmp_var5_node = Apply(op3, [var3, var4], [MyVariable("var5_tmp")]) + assert cb_tracker.callback_history == [ - ("attach", (fg,), {}), - ("import", (fg, var3.owner, "init"), {}), - ("import", (fg, var4.owner, "init"), {}), - ("import", (fg, var5.owner, "init"), {}), - ("change_input", (fg, var5.owner, 0, var4, var3), {"reason": None}), + ("prune", (fg, old_var5_node, None), {}), + ("import", (fg, tmp_var5_node, None), {}), + ( + "change_input", + (fg, old_var5_node, tmp_var5_node, 0, var4, var3), + {"reason": None}, + ), ("prune", (fg, var4.owner, None), {}), - ("change_input", (fg, var5.owner, 1, var4, var3), {"reason": None}), + ("prune", (fg, tmp_var5_node, None), {}), + ("import", (fg, var5.owner, None), {}), + ( + "change_input", + (fg, tmp_var5_node, var5.owner, 1, var4, var3), + {"reason": None}, + ), ] def test_replace_outputs(self): @@ -535,9 +592,9 @@ def test_replace_outputs(self): ("attach", (fg,), {}), ("import", (fg, var3.owner, "init"), {}), ("import", (fg, var4.owner, "init"), {}), - ("change_input", (fg, "output", 0, var3, var1), {"reason": None}), + ("change_input", (fg, "output", "output", 0, var3, var1), {"reason": None}), ("prune", (fg, var3.owner, None), {}), - ("change_input", (fg, "output", 2, var3, var1), {"reason": None}), + ("change_input", (fg, "output", "output", 2, var3, var1), {"reason": None}), ] def test_replace_contract(self): @@ -555,6 +612,18 @@ def test_replace_contract(self): cb_tracker = CallbackTracker() fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_v3_node = v3.owner + old_v4_node = v4.owner + # This replacement should produce a new `Apply` node that's equivalent # to `v2` and try to replace `v3`'s node with that one. In other # words, the replacement creates a new node that's already in the @@ -566,7 +635,7 @@ def test_replace_contract(self): assert fg.clients == { x: [(v1.owner, 0)], v1: [(v3.owner, 0)], - v2: [], + # v2: [], v3: [(v4.owner, 0)], v4: [("output", 0)], } @@ -574,13 +643,9 @@ def test_replace_contract(self): assert v2 not in set(sum((n.outputs for n in fg.apply_nodes), [])) assert cb_tracker.callback_history == [ - ("attach", (fg,), {}), - ("import", (fg, v1.owner, "init"), {}), - ("import", (fg, v2.owner, "init"), {}), - ("import", (fg, v3.owner, "init"), {}), - ("import", (fg, v4.owner, "init"), {}), - ("prune", (fg, v2.owner, None), {}), - ("change_input", (fg, v3.owner, 0, v2, v1), {"reason": None}), + ("prune", (fg, old_v3_node, None), {}), + ("import", (fg, v3.owner, None), {}), + ("change_input", (fg, old_v3_node, v3.owner, 0, v2, v1), {"reason": None}), ] # Let's try the same thing at a different point in the chain @@ -598,6 +663,17 @@ def test_replace_contract(self): cb_tracker = CallbackTracker() fg = FunctionGraph([x], [v4], clone=False, features=[cb_tracker]) + assert cb_tracker.callback_history == [ + ("attach", (fg,), {}), + ("import", (fg, v1.owner, "init"), {}), + ("import", (fg, v2.owner, "init"), {}), + ("import", (fg, v3.owner, "init"), {}), + ("import", (fg, v4.owner, "init"), {}), + ] + cb_tracker.callback_history.clear() + + old_v4_node = v4.owner + fg.replace_all([(v3, v2)]) assert v3 not in fg.variables @@ -605,20 +681,16 @@ def test_replace_contract(self): x: [(v1.owner, 0)], v1: [(v2.owner, 0)], v2: [(v4.owner, 0)], - v3: [], + # v3: [], v4: [("output", 0)], } assert fg.apply_nodes == {v4.owner, v2.owner, v1.owner} assert v3 not in set(sum((n.outputs for n in fg.apply_nodes), [])) exp_res = [ - ("attach", (fg,), {}), - ("import", (fg, v1.owner, "init"), {}), - ("import", (fg, v2.owner, "init"), {}), - ("import", (fg, v3.owner, "init"), {}), - ("import", (fg, v4.owner, "init"), {}), - ("prune", (fg, v3.owner, None), {}), - ("change_input", (fg, v4.owner, 0, v3, v2), {"reason": None}), + ("prune", (fg, old_v4_node, None), {}), + ("import", (fg, v4.owner, None), {}), + ("change_input", (fg, old_v4_node, v4.owner, 0, v3, v2), {"reason": None}), ] assert cb_tracker.callback_history == exp_res @@ -667,25 +739,31 @@ def test_replace_circular(self): ) cb_tracker.callback_history.clear() + old_var4_owner = var4.owner + fg.replace_all([(var3, var4)]) # The following works (and is kind of gross), because `var4` has been # mutated in-place - assert fg.apply_nodes == {var4.owner, var5.owner} assert var4.owner.inputs == [var4, var2] + assert fg.apply_nodes == {var4.owner, var5.owner} + assert fg.outputs == [var4, var5] - assert len(cb_tracker.callback_history) == 3 - assert cb_tracker.callback_history[0] == ( - "change_input", - (fg, "output", 0, var3, var4), - {"reason": None}, - ) - assert cb_tracker.callback_history[1] == ("prune", (fg, var3.owner, None), {}) - assert cb_tracker.callback_history[2] == ( - "change_input", - (fg, var4.owner, 0, var3, var4), - {"reason": None}, - ) + assert cb_tracker.callback_history == [ + ( + "change_input", + (fg, "output", "output", 0, var3, var4), + {"reason": None}, + ), + ("prune", (fg, var3.owner, None), {}), + ("prune", (fg, old_var4_owner, None), {}), + ("import", (fg, var4.owner, None), {}), + ( + "change_input", + (fg, old_var4_owner, var4.owner, 0, var3, var4), + {"reason": None}, + ), + ] def test_replace_bad_state(self): @@ -708,8 +786,11 @@ def test_check_integrity(self): var1 = MyVariable("var1") var2 = MyVariable("var2") var3 = op1(var2, var1) + var3.name = "var3" var4 = op2(var3, var2) + var4.name = "var4" var5 = op3(var4, var2, var2) + var5.name = "var5" fg = FunctionGraph([var1, var2], [var3, var5], clone=False) with pytest.raises(Exception, match="The following nodes are .*"): @@ -733,15 +814,24 @@ def test_check_integrity(self): fg.variables.add(var4) with pytest.raises(Exception, match="Undeclared input.*"): - var6 = MyVariable2("var6") - fg.clients[var6] = [(var5.owner, 3)] + var6 = MyVariable("var6") + var7 = op1(var6) + var7.name = "var7" + fg.clients[var6] = [(var7.owner, 0)] fg.variables.add(var6) - var5.owner.inputs.append(var6) + fg.clients[var7] = [("output", 2)] + fg.variables.add(var7) + fg.outputs.append(var7) + fg.apply_nodes.add(var7.owner) fg.check_integrity() fg.variables.remove(var6) - var5.owner.inputs.remove(var6) + fg.variables.remove(var7) + del fg.clients[var6] + del fg.clients[var7] + fg.outputs.remove(var7) + fg.apply_nodes.remove(var7.owner) # TODO: What if the index value is greater than 1? It will throw an # `IndexError`, but that doesn't sound like anything we'd want.