Commit
Merge branch 'main' into add_blockwise
purna135 authored Sep 26, 2022
2 parents 7245341 + ec82b9f commit a57528c
Showing 246 changed files with 36,455 additions and 33,157 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/test.yml
@@ -54,6 +54,8 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
+        with:
+          python-version: '3.9'
       - uses: pre-commit/action@v2.0.0
 
   test:
@@ -72,9 +74,9 @@ jobs:
         install-numba: [1]
         part:
           - "tests --ignore=tests/tensor --ignore=tests/sparse --ignore=tests/tensor/nnet"
-          - "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_basic_opt.py --ignore=tests/tensor/test_math_opt.py --ignore=tests/tensor/nnet"
+          - "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/rewriting/test_basic.py --ignore=tests/tensor/rewriting/test_math.py --ignore=tests/tensor/nnet"
           - "tests/tensor/test_basic.py tests/tensor/test_math.py tests/tensor/test_math_scipy.py tests/tensor/test_inplace.py"
-          - "tests/tensor/test_elemwise.py tests/tensor/test_basic_opt.py tests/tensor/test_math_opt.py"
+          - "tests/tensor/test_elemwise.py tests/tensor/rewriting/test_basic.py tests/tensor/rewriting/test_math.py"
           - "tests/tensor/nnet --ignore-glob='*/test_abstract_conv.py'"
           - "tests/tensor/nnet/test_abstract_conv.py"
         include:
@@ -143,7 +145,7 @@ jobs:
         run: |
           mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov sympy
           if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
-          mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
+          mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax "jaxlib!=0.3.15"
           pip install -e ./
           mamba list && pip freeze
           python -c 'import aesara; print(aesara.config.__str__(print_doc=False))'
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -7,29 +7,29 @@ exclude: |
   )$
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.1.0
+  rev: v4.3.0
   hooks:
   - id: debug-statements
     exclude: |
       (?x)^(
           aesara/breakpoint\.py|
           aesara/graph/op\.py|
           aesara/compile/nanguardmode\.py|
-          aesara/graph/opt\.py|
+          aesara/graph/rewriting/basic\.py|
           aesara/tensor/var\.py|
       )$
   - id: check-merge-conflict
 - repo: https://github.com/psf/black
-  rev: 22.3.0
+  rev: 22.8.0
   hooks:
   - id: black
     language_version: python3
 - repo: https://gitlab.com/pycqa/flake8
-  rev: 3.8.4
+  rev: 3.9.2
   hooks:
   - id: flake8
 - repo: https://github.com/pycqa/isort
-  rev: 5.6.4
+  rev: 5.10.1
   hooks:
   - id: isort
 - repo: https://github.com/humitos/mirrors-autoflake.git
@@ -47,7 +47,7 @@ repos:
       )$
     args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.931
+  rev: v0.971
   hooks:
   - id: mypy
     additional_dependencies:
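For reference, pre-commit matches each hook's exclude pattern against the staged file path as a Python verbose regex. A small sketch of how the updated debug-statements exclude list behaves (re.match here only approximates pre-commit's matching; the trailing | from the original pattern is kept verbatim):

    import re

    # The updated debug-statements exclude list from the hunk above, compiled
    # as a verbose "(?x)" regex in which whitespace is ignored.
    exclude = re.compile(
        r"""(?x)^(
        aesara/breakpoint\.py|
        aesara/graph/op\.py|
        aesara/compile/nanguardmode\.py|
        aesara/graph/rewriting/basic\.py|
        aesara/tensor/var\.py|
        )$"""
    )

    assert exclude.match("aesara/graph/rewriting/basic.py")
    assert not exclude.match("aesara/graph/opt.py")  # the old, now-renamed path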
30 changes: 24 additions & 6 deletions aesara/__init__.py
@@ -62,12 +62,6 @@ def disable_log_handler(logger=aesara_logger, handler=logging_default_handler):
     raise RuntimeError("You have the aesara directory in your Python path.")
 
 from aesara.configdefaults import config
-from aesara.utils import deprecated
-
-
-change_flags = deprecated("Use aesara.config.change_flags instead!")(
-    config.change_flags
-)
 
 
 # This is the api version for ops that generate C code. External ops
@@ -178,3 +172,27 @@ def get_scalar_constant_value(v):
 # imports were executed, we can warn about remaining flags provided by the user
 # through AESARA_FLAGS.
 config.warn_unused_flags()
+
+DEPRECATED_NAMES = [
+    (
+        "change_flags",
+        "`aesara.change_flags` is deprecated: use `aesara.config.change_flags` instead.",
+        config.change_flags,
+    ),
+]
+
+
+def __getattr__(name):
+    """Intercept module-level attribute access of deprecated symbols.
+
+    Adapted from https://stackoverflow.com/a/55139609/3006474.
+    """
+    from warnings import warn
+
+    for old_name, msg, old_object in DEPRECATED_NAMES:
+        if name == old_name:
+            warn(msg, DeprecationWarning, stacklevel=2)
+            return old_object
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
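Because a module-level __getattr__ (PEP 562) only runs when normal attribute lookup fails, deleting the old change_flags binding above is what routes the name through the deprecation path. A minimal usage sketch, assuming this commit is installed:

    import warnings

    import aesara

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        flags = aesara.change_flags  # resolved by the module-level __getattr__

    # The old name still returns the real object, but with a warning attached.
    assert flags is aesara.config.change_flags
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)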
6 changes: 3 additions & 3 deletions aesara/compile/builders.py
@@ -24,9 +24,9 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.null_type import NullType
 from aesara.graph.op import HasInnerGraph, Op
-from aesara.graph.opt import in2out, local_optimizer
+from aesara.graph.rewriting.basic import in2out, node_rewriter
 from aesara.graph.utils import MissingInputError
-from aesara.tensor.basic_opt import ShapeFeature
+from aesara.tensor.rewriting.shape import ShapeFeature
 
 
 def infer_shape(outs, inputs, input_shapes):
@@ -928,7 +928,7 @@ def perform(self, node, inputs, outputs):
         output[0] = variable
 
 
-@local_optimizer([OpFromGraph])
+@node_rewriter([OpFromGraph])
 def inline_ofg_expansion(fgraph, node):
     """
     This optimization expands internal graph of OpFromGraph.
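The decorator rename is mechanical: node_rewriter has the same shape as the old local_optimizer. A hedged sketch of a rewrite registered the same way as inline_ofg_expansion above (my_noop_rewrite is a hypothetical name, not part of this commit):

    from aesara.compile.builders import OpFromGraph
    from aesara.graph.rewriting.basic import node_rewriter

    @node_rewriter([OpFromGraph])
    def my_noop_rewrite(fgraph, node):
        # A node rewriter receives the graph and one matching node; it returns
        # replacement outputs, or None/False to leave the node untouched.
        return None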
5 changes: 1 addition & 4 deletions aesara/compile/function/pfunc.py
@@ -189,11 +189,8 @@ def clone_inputs(i):
                     (store_into, update_d[store_into]),
                 )
 
-            # filter_variable ensure smooth conversion of cpu Types
             try:
-                update_val = store_into.type.filter_variable(
-                    update_val, allow_convert=False
-                )
+                update_val = store_into.type.filter_variable(update_val, allow_convert=True)
             except TypeError:
                 err_msg = (
                     "An update must have the same type as the"
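With allow_convert=True, an update whose type is compatible with, but not identical to, the shared variable's type is converted rather than rejected; with allow_convert=False only an exact type match passed. A hedged sketch of the kind of update this change accepts (behaviour assumed from TensorType.filter_variable):

    import numpy as np

    import aesara
    import aesara.tensor as at

    s = aesara.shared(np.zeros((2, 3)), name="s")
    # specify_shape gives the update a more precise static type than `s` may
    # carry; filter_variable(..., allow_convert=True) converts such an update
    # instead of raising TypeError.
    update = at.specify_shape(s + 1.0, (2, 3))
    f = aesara.function([], [], updates=[(s, update)])
    f()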
62 changes: 29 additions & 33 deletions aesara/compile/function/types.py
@@ -1,7 +1,4 @@
"""
Driver of graph construction, optimization, and linking.
"""
"""Objects that orchestrate graph construction, rewriting, and linking."""

import copy
import copyreg
@@ -753,9 +750,8 @@ def checkSV(sv_ori, sv_rpl):
             # cause problems.
             on_unused_input="ignore",
             function_builder=maker.function_builder,
-            # As this is an optimized graph, it
-            # can contain inplace. DebugMode check
-            # that.
+            # As this is a rewritten graph, it can contain inplace ops.
+            # DebugMode checks that.
             accept_inplace=True,
             no_fgraph_prep=True,
         ).create(input_storage, storage_map=new_storage_map)
@@ -1182,7 +1178,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
     This loop was inserted to remove aliasing between outputs when they all
     evaluate to the same value. Originally it was OK for outputs to be aliased,
     but some of the outputs can be shared variables, and is not good for shared
-    variables to be aliased. It might be possible to optimize this by making
+    variables to be aliased. It might be possible to rewrite this by making
     sure there is no aliasing only between shared variables.
 
     If some outputs are constant, we add deep copy to respect the memory
@@ -1279,7 +1275,7 @@ class FunctionMaker:
"""
`FunctionMaker` is the class to `create` `Function` instances.
This class has the fgraph, the optimizer, and the linker. When
This class has the fgraph, the rewriter, and the linker. When
copying a `Function`, there is no need to duplicate the
`FunctionMaker` instance. Deepcopy still copies both, which can
variable in re-compilation.
@@ -1292,7 +1288,7 @@ class FunctionMaker:
         functions produced by FunctionMaker will return their output value
         directly.
     mode : Mode instance
-        Telling FunctionMaker how to optimize and link. None means to use the
+        Telling FunctionMaker how to rewrite and link. None means to use the
         `config.mode`.
     accept_inplace : bool
         True iff it is acceptable to have inplace operations in the graph from
@@ -1395,44 +1391,44 @@ def check_unused_inputs(inputs, outputs, on_unused_input):

     @staticmethod
     def prepare_fgraph(
-        inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile
+        inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile
     ):
 
         try:
-            start_optimizer = time.time()
+            start_rewriter = time.time()
 
-            optimizer_profile = None
-            opt_time = None
+            rewriter_profile = None
+            rewrite_time = None
 
             with config.change_flags(
                 compute_test_value=config.compute_test_value_opt,
                 traceback__limit=config.traceback__compile_limit,
             ):
-                optimizer_profile = optimizer(fgraph)
+                rewriter_profile = rewriter(fgraph)
 
-            end_optimizer = time.time()
-            opt_time = end_optimizer - start_optimizer
-            _logger.debug(f"Optimizing took {opt_time:f} seconds")
+            end_rewriter = time.time()
+            rewrite_time = end_rewriter - start_rewriter
+            _logger.debug(f"Rewriting took {rewrite_time:f} seconds")
 
             # Add deep copy to respect the memory interface
             insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
         finally:
 
-            # If the optimizer got interrupted
-            if opt_time is None:
-                end_optimizer = time.time()
-                opt_time = end_optimizer - start_optimizer
+            # If the rewriter got interrupted
+            if rewrite_time is None:
+                end_rewriter = time.time()
+                rewrite_time = end_rewriter - start_rewriter
 
-            aesara.compile.profiling.total_graph_opt_time += opt_time
+            aesara.compile.profiling.total_graph_rewrite_time += rewrite_time
 
             if profile:
-                if optimizer_profile is None and hasattr(optimizer, "pre_profile"):
-                    optimizer_profile = optimizer.pre_profile
+                if rewriter_profile is None and hasattr(rewriter, "pre_profile"):
+                    rewriter_profile = rewriter.pre_profile
 
-                profile.optimizer_time += opt_time
+                profile.rewriting_time += rewrite_time
 
                 if config.profile_optimizer:
-                    profile.optimizer_profile = (optimizer, optimizer_profile)
+                    profile.rewriter_profile = (rewriter, rewriter_profile)
             elif config.profile_optimizer and profile is not False:
                 # If False, it means the profiling for that function was
                 # explicitly disabled
@@ -1466,8 +1462,8 @@ def __init__(
     ):
         # Save the provided mode, not the instantiated mode.
         # The instantiated mode don't pickle and if we unpickle an Aesara
-        # function and it get re-compiled, we want the current optimizer to be
-        # used, not the optimizer when it was saved.
+        # function and it gets re-compiled, we want the current rewriter to be
+        # used, not the rewriter when it was saved.
         self.mode = mode
         mode = aesara.compile.mode.get_mode(mode)
 
@@ -1478,7 +1474,7 @@ def __init__(
         if profile:
             # This is very important:
             # 1) We preload the cache here to not have its timing
-            #    included in optimization that compile function.
+            #    included with the rewrites.
             # 2) Do not refresh the cache here by default. It cause
             #    too much execution time during testing as we compile
             #    much more functions then the number of compile c
@@ -1515,11 +1511,11 @@ def __init__(

         self.fgraph = fgraph
 
-        optimizer, linker = mode.optimizer, copy.copy(mode.linker)
+        rewriter, linker = mode.optimizer, copy.copy(mode.linker)
 
         if not no_fgraph_prep:
             self.prepare_fgraph(
-                inputs, outputs, found_updates, fgraph, optimizer, linker, profile
+                inputs, outputs, found_updates, fgraph, rewriter, linker, profile
             )
 
         assert len(fgraph.outputs) == len(outputs + found_updates)
@@ -1715,7 +1711,7 @@ def orig_function(
         time spent in this function.
     accept_inplace : bool
         True iff the graph can contain inplace operations prior to the
-        optimization phase (default is False).
+        rewrite phase (default is False).
     profile : None or ProfileStats instance
     on_unused_input : {'raise', 'warn', 'ignore', None}
         What to do if a variable in the 'inputs' list is not used in the graph.
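The renaming in prepare_fgraph is user-visible through profiling. A hedged sketch of reading the new field (assumes this commit is installed and that Function.profile is populated when profile=True, as in earlier releases):

    import aesara
    import aesara.tensor as at

    x = at.vector("x")
    f = aesara.function([x], 2 * x, profile=True)

    # `rewriting_time` replaces the old `optimizer_time` field; it accumulates
    # the seconds spent rewriting the graph during compilation.
    print(f.profile.rewriting_time)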