Skip to content

Commit

Permalink
Move basic rewriting code to aesara.graph.rewriting
Browse files Browse the repository at this point in the history
- `aesara.graph.opt` has been changed to `aesara.graph.rewriting.basic`
- `aesara.graph.opt_utils` has been changed to `aesara.graph.rewriting.utils`
- `aesara.graph.optdb` has been changed to `aesara.graph.rewriting.db`
- `aesara.graph.unify` has been changed to `aesara.graph.rewriting.unify`
- `aesara.graph.kanren` has been changed to `aesara.graph.rewriting.kanren`

The tests associated with each module have been updated accordingly.
  • Loading branch information
brandonwillard committed Aug 17, 2022
1 parent 746eecb commit 7550668
Show file tree
Hide file tree
Showing 66 changed files with 146 additions and 134 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
aesara/breakpoint\.py|
aesara/graph/op\.py|
aesara/compile/nanguardmode\.py|
aesara/graph/opt\.py|
aesara/graph/rewriting/basic\.py|
aesara/tensor/var\.py|
)$
- id: check-merge-conflict
Expand Down
2 changes: 1 addition & 1 deletion aesara/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, node_rewriter
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError
from aesara.tensor.basic_opt import ShapeFeature

Expand Down
4 changes: 2 additions & 2 deletions aesara/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from aesara.compile.function.types import Supervisor
from aesara.configdefaults import config
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.opt import (
from aesara.graph.rewriting.basic import (
CheckStackTraceRewriter,
GraphRewriter,
MergeOptimizer,
NodeProcessingGraphRewriter,
)
from aesara.graph.optdb import (
from aesara.graph.rewriting.db import (
EquilibriumDB,
LocalGroupDB,
RewriteDatabase,
Expand Down
6 changes: 3 additions & 3 deletions aesara/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import node_rewriter, graph_rewriter
from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.rewriting.basic import node_rewriter, graph_rewriter
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.rewriting.db import RewriteDatabaseQuery

# isort: on
Empty file.
6 changes: 3 additions & 3 deletions aesara/graph/opt.py → aesara/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@


if TYPE_CHECKING:
from aesara.graph.unify import Var
from aesara.graph.rewriting.unify import Var


_logger = logging.getLogger("aesara.graph.opt")
_logger = logging.getLogger("aesara.graph.rewriting.basic")

RemoveKeyType = Literal["remove"]
TransformOutputType = Union[
Expand Down Expand Up @@ -1586,7 +1586,7 @@ def __init__(
often.
"""
from aesara.graph.unify import convert_strs_to_vars
from aesara.graph.rewriting.unify import convert_strs_to_vars

var_map: Dict[str, "Var"] = {}
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
Expand Down
2 changes: 1 addition & 1 deletion aesara/graph/optdb.py → aesara/graph/rewriting/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union

from aesara.configdefaults import config
from aesara.graph import opt as aesara_rewriting
from aesara.graph.rewriting import basic as aesara_rewriting
from aesara.misc.ordered_set import OrderedSet
from aesara.utils import DefaultOrderedDict

Expand Down
6 changes: 3 additions & 3 deletions aesara/graph/kanren.py → aesara/graph/rewriting/kanren.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from unification.variable import Var

from aesara.graph.basic import Apply, Variable
from aesara.graph.opt import NodeRewriter
from aesara.graph.unify import eval_if_etuple
from aesara.graph.rewriting.basic import NodeRewriter
from aesara.graph.rewriting.unify import eval_if_etuple


class KanrenRelationSub(NodeRewriter):
Expand All @@ -24,7 +24,7 @@ class KanrenRelationSub(NodeRewriter):
from kanren import eq, conso, var
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.rewriting.kanren import KanrenRelationSub
def relation(in_lv, out_lv):
Expand Down
2 changes: 1 addition & 1 deletion aesara/graph/unify.py → aesara/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __str__(self):
return f"~{self.token} [{self.constraint}]"

def __repr__(self):
return f"ConstrainedVar({repr(self.constraint)}, {self.token})"
return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"


def car_Variable(x):
Expand Down
6 changes: 3 additions & 3 deletions aesara/graph/opt_utils.py → aesara/graph/rewriting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
vars_between,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import RewriteDatabaseQuery
from aesara.graph.rewriting.db import RewriteDatabaseQuery


if TYPE_CHECKING:
from aesara.graph.opt import GraphRewriter
from aesara.graph.rewriting.basic import GraphRewriter


def rewrite_graph(
Expand Down Expand Up @@ -89,7 +89,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
See help on `aesara.graph.basic.is_same_graph` for additional documentation.
"""
from aesara.graph.opt import MergeOptimizer
from aesara.graph.rewriting.basic import MergeOptimizer

if givens is None:
givens = {}
Expand Down
7 changes: 3 additions & 4 deletions aesara/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ def simple_extract_stack(
if len(trace) == 0:
rm = False
for p in skips:
# Julian: I added the 'tests' exception together with
# Arnaud. Otherwise, we'd lose the stack trace during
# in our test cases (e.g. in test_opt.py). We're not
# sure this is the right way to do it though.
# The 'tests' exception was added; otherwise, we'd lose the
# stack trace during in our test cases. We're not sure this is
# the right way to do it, though.
if p in filename and "tests" not in filename:
rm = True
break
Expand Down
8 changes: 4 additions & 4 deletions aesara/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GraphRewriter, in2out, node_rewriter
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast

Expand Down Expand Up @@ -435,8 +435,8 @@ def cond_make_inplace(fgraph, node):
# XXX: Optimizations commented pending further debugging (certain optimizations
# make computation less lazy than it should be currently).
#
# ifelse_equilibrium = graph.optdb.EquilibriumDB()
# ifelse_seqopt = graph.optdb.SequenceDB()
# ifelse_equilibrium = graph.rewriting.db.EquilibriumDB()
# ifelse_seqopt = graph.rewriting.db.SequenceDB()
# ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run',
# 'ifelse')
""" Comments:
Expand Down Expand Up @@ -738,7 +738,7 @@ def cond_merge_random_op(fgraph, main_node):
# XXX: Optimizations commented pending further debugging (certain optimizations
# make computation less lazy than it should be currently).
#
# pushout_equilibrium = graph.optdb.EquilibriumDB()
# pushout_equilibrium = graph.rewriting.db.EquilibriumDB()
#
# XXX: This optimization doesn't seem to exist anymore?
# pushout_equilibrium.register("cond_lift_single_if",
Expand Down
2 changes: 1 addition & 1 deletion aesara/sandbox/linalg/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from aesara.graph.opt import node_rewriter
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor import basic as at
from aesara.tensor.basic_opt import (
register_canonicalize,
Expand Down
2 changes: 1 addition & 1 deletion aesara/sandbox/rng_mrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from aesara.configdefaults import config
from aesara.gradient import undefined_grad
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.opt import in2out, node_rewriter
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.link.c.op import COp, Op
from aesara.link.c.params_type import ParamsType
from aesara.sandbox import multinomial
Expand Down
2 changes: 1 addition & 1 deletion aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from aesara.gradient import DisconnectedType, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import MergeOptimizer
from aesara.graph.rewriting.basic import MergeOptimizer
from aesara.graph.type import HasDataType, HasShape
from aesara.graph.utils import MetaObject, MethodNotDefined
from aesara.link.c.op import COp
Expand Down
4 changes: 2 additions & 2 deletions aesara/scan/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GraphRewriter, in2out, node_rewriter
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.graph.rewriting.db import EquilibriumDB, SequenceDB
from aesara.graph.type import HasShape
from aesara.graph.utils import InconsistencyError
from aesara.scan.op import Scan, ScanInfo
Expand Down
6 changes: 5 additions & 1 deletion aesara/sparse/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import aesara.scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.opt import PatternNodeRewriter, WalkingGraphRewriter, node_rewriter
from aesara.graph.rewriting.basic import (
PatternNodeRewriter,
WalkingGraphRewriter,
node_rewriter,
)
from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
Expand Down
6 changes: 3 additions & 3 deletions aesara/tensor/basic_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.opt import (
from aesara.graph.rewriting.basic import (
GraphRewriter,
NodeRewriter,
RemovalNodeRewriter,
Expand All @@ -36,7 +36,7 @@
in2out,
node_rewriter,
)
from aesara.graph.optdb import RewriteDatabase, SequenceDB
from aesara.graph.rewriting.db import RewriteDatabase, SequenceDB
from aesara.graph.utils import (
InconsistencyError,
MethodNotDefined,
Expand Down Expand Up @@ -1433,7 +1433,7 @@ def same_shape(
clone=True,
# copy_inputs=False,
)
from aesara.graph.opt_utils import rewrite_graph
from aesara.graph.rewriting.utils import rewrite_graph

canon_shapes = rewrite_graph(
shapes_fg, custom_rewrite=topo_constant_folding
Expand Down
6 changes: 3 additions & 3 deletions aesara/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@
from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import Op
from aesara.graph.opt import (
from aesara.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.graph.optdb import SequenceDB
from aesara.graph.rewriting.db import SequenceDB
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
Expand Down Expand Up @@ -1526,7 +1526,7 @@ def on_import(new_node):
if new_node is not node:
nodelist.append(new_node)

u = aesara.graph.opt.DispatchingFeature(
u = aesara.graph.rewriting.basic.DispatchingFeature(
on_import, None, None, name="GemmOptimizer"
)
fgraph.attach_feature(u)
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/blas_c.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from aesara.configdefaults import config
from aesara.graph.opt import in2out
from aesara.graph.rewriting.basic import in2out
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.scalar import bool as bool_t
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/blas_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from aesara.graph.opt import in2out
from aesara.graph.rewriting.basic import in2out
from aesara.tensor.blas import (
Ger,
blas_optdb,
Expand Down
6 changes: 3 additions & 3 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import aesara.scalar.basic as aes
import aesara.scalar.math as aes_math
from aesara.graph.basic import Constant, Variable
from aesara.graph.opt import (
from aesara.graph.rewriting.basic import (
NodeRewriter,
PatternNodeRewriter,
SequentialNodeRewriter,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.graph.opt_utils import get_clients_at_depth
from aesara.graph.rewriting.utils import get_clients_at_depth
from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import assert_op
from aesara.tensor.basic import (
Expand Down Expand Up @@ -1941,7 +1941,7 @@ def local_pow_specialize_device(fgraph, node):

# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
# That happen in the test_opt.py:test_log_erfc test.
# That happen in the `test_log_erfc` test.
# y is a ndarray with dtype int8 and value 2,4 or 6. This make
# the abs(y) <= 512 fail!
# taking the value outside ndarray solve the problem.
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/nnet/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, graph_rewriter, node_rewriter
from aesara.graph.rewriting.basic import copy_stack_trace, graph_rewriter, node_rewriter
from aesara.link.c.op import COp
from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp
Expand Down
4 changes: 2 additions & 2 deletions aesara/tensor/nnet/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, node_rewriter
from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div
from aesara.tensor import basic as at
from aesara.tensor.basic import as_tensor_variable
Expand Down Expand Up @@ -896,7 +896,7 @@ def local_abstract_batch_norm_inference(fgraph, node):


# Register Cpu Optimization
bn_groupopt = aesara.graph.optdb.LocalGroupDB()
bn_groupopt = aesara.graph.rewriting.db.LocalGroupDB()
bn_groupopt.__name__ = "batchnorm_opts"
register_specialize_device(bn_groupopt, "fast_compile", "fast_run")

Expand Down
6 changes: 5 additions & 1 deletion aesara/tensor/nnet/conv3d2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.graph.opt import WalkingGraphRewriter, copy_stack_trace, node_rewriter
from aesara.graph.rewriting.basic import (
WalkingGraphRewriter,
copy_stack_trace,
node_rewriter,
)


def get_diagonal_subtensor_view(x, i0, i1):
Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/nnet/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aesara.configdefaults import config
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply
from aesara.graph.opt import node_rewriter
from aesara.graph.rewriting.basic import node_rewriter
from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize
Expand Down
4 changes: 2 additions & 2 deletions aesara/tensor/nnet/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from aesara import compile
from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.graph.opt import (
from aesara.graph.rewriting.basic import (
MetaNodeRewriterSkip,
WalkingGraphRewriter,
copy_stack_trace,
Expand Down Expand Up @@ -486,7 +486,7 @@ def local_conv2d_gradinputs_cpu(fgraph, node):


# Register Cpu Optimization
conv_groupopt = aesara.graph.optdb.LocalGroupDB()
conv_groupopt = aesara.graph.rewriting.db.LocalGroupDB()
conv_groupopt.__name__ = "conv_opts"
register_specialize_device(conv_groupopt, "fast_compile", "fast_run")

Expand Down
Loading

0 comments on commit 7550668

Please sign in to comment.