Skip to content

Commit

Permalink
Hash-cons Apply, Constant and change node input replacement semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 14, 2022
1 parent 3c665a5 commit 3e9665c
Show file tree
Hide file tree
Showing 19 changed files with 1,034 additions and 765 deletions.
117 changes: 57 additions & 60 deletions aesara/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from itertools import chain
from itertools import product as itertools_product
from logging import Logger
from typing import Optional
from typing import TYPE_CHECKING, Optional, Union
from warnings import warn

import numpy as np
from typing_extensions import Literal

import aesara
from aesara.compile.function.types import (
Expand All @@ -42,7 +43,9 @@
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function


__docformat__ = "restructuredtext en"
if TYPE_CHECKING:
from aesara.graph.basic import Apply

_logger: Logger = logging.getLogger("aesara.compile.debugmode")
_logger.addFilter(NoDuplicateOptWarningFilter())

Expand Down Expand Up @@ -1109,43 +1112,32 @@ class _FunctionGraphEvent:
"""

kind = ""
"""
One of 'import', 'change', 'prune'.
"""

node = None
"""
Either 'output' or an Apply instance.
"""

op = None
"""Either 'output' or an Op instance"""
kind: Literal["import", "change", "prune"]
old_node: Optional[Union[Literal["output"], "Apply"]]
new_node: Optional[Union[Literal["output"], "Apply"]]
op: Optional[Union[Literal["output"], Op]]
idx: Optional[int]
reason: Optional[str]

idx = None
"""
Change events involve an position index of the input variable.
"""

reason = None
"""
Change events sometimes have a reason.
"""

def __init__(self, kind, node, idx=None, reason=None):
def __init__(
self,
kind: Literal["import", "change", "prune"],
old_node: Union[Literal["output"], "Apply"],
new_node: Union[Literal["output"], "Apply"] = None,
idx: Optional[int] = None,
reason: Optional[str] = None,
):
self.kind = kind
if node == "output":
self.node = "output"
if old_node == "output":
self.old_node = "output"
self.new_node = "output"
self.op = "output"
else:
self.node = node
self.op = node.op
self.old_node = old_node
self.new_node = new_node
self.op = old_node.op
self.idx = idx
self.reason = str(reason)
self.reason = str(reason) if reason else None

def __str__(self):
if self.kind == "change":
Expand Down Expand Up @@ -1219,21 +1211,21 @@ def on_attach(self, fgraph):
self.replaced_by = {}
self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
self.on_import(fgraph, node, reason="on_attach")

def on_detach(self, fgraph):
assert fgraph is self.fgraph
self.fgraph = None

def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
self.event_list.append(_FunctionGraphEvent("prune", node, reason=reason))
assert node in self.active_nodes
assert node not in self.inactive_nodes
self.active_nodes.remove(node)
self.inactive_nodes.add(node)

def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
self.event_list.append(_FunctionGraphEvent("import", node, reason=reason))

assert node not in self.active_nodes
self.active_nodes.add(node)
Expand All @@ -1253,31 +1245,36 @@ def on_import(self, fgraph, node, reason):
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
def on_change_input(
self, fgraph, old_node, new_node, i, old_var, new_var, reason=None
):
reason = str(reason)
self.event_list.append(
_FunctionGraphEvent("change", node, reason=reason, idx=i)
_FunctionGraphEvent("change", old_node, new_node, idx=i, reason=reason)
)

self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
self.on_import(fgraph, new_node, reason=reason)
self.on_prune(fgraph, old_node, reason=reason)

self.reasons.setdefault(new_var, [])
self.replaced_by.setdefault(new_var, [])

append_reason = True
for tup in self.reasons[new_r]:
if tup[0] == reason and tup[1] is r:
for tup in self.reasons[new_var]:
if tup[0] == reason and tup[1] is old_var:
append_reason = False

if append_reason:
# N.B. compute the debugprint now, because future
# optimizations will change the graph
done = dict()
used_ids = dict()
self.reasons[new_r].append(
self.reasons[new_var].append(
(
reason,
r,
old_var,
_debugprint(
r,
old_var,
prefix=" ",
depth=6,
file=StringIO(),
Expand All @@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
used_ids=used_ids,
).getvalue(),
_debugprint(
new_r,
new_var,
prefix=" ",
depth=6,
file=StringIO(),
Expand All @@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
).getvalue(),
)
)
self.replaced_by[r].append((reason, new_r))
self.replaced_by[old_var].append((reason, new_var))

if r in self.equiv:
r_set = self.equiv[r]
if old_var in self.equiv:
r_set = self.equiv[old_var]
else:
r_set = self.equiv.setdefault(r, {r})
self.all_variables_ever.append(r)
r_set = self.equiv.setdefault(old_var, {old_var})
self.all_variables_ever.append(old_var)

if new_r in self.equiv:
new_r_set = self.equiv[new_r]
if new_var in self.equiv:
new_r_set = self.equiv[new_var]
else:
new_r_set = self.equiv.setdefault(new_r, {new_r})
self.all_variables_ever.append(new_r)
new_r_set = self.equiv.setdefault(new_var, {new_var})
self.all_variables_ever.append(new_var)

assert new_r in new_r_set
assert r in r_set
assert new_var in new_r_set
assert old_var in r_set

# update one equivalence set to contain the other
# transfer all the elements of the old one to the new one
Expand All @@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
self.equiv[like_new_r] = r_set
assert like_new_r in r_set

assert self.equiv[r] is r_set
assert self.equiv[new_r] is r_set
assert self.equiv[old_var] is r_set
assert self.equiv[new_var] is r_set

def printstuff(self):
for key in self.equiv:
Expand Down
Loading

0 comments on commit 3e9665c

Please sign in to comment.