Skip to content

Commit

Permalink
Fuse consecutive Elemwise subgraphs with multiple outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo Vieira committed Oct 6, 2022
1 parent 6c6bf08 commit 074f125
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 55 deletions.
229 changes: 195 additions & 34 deletions aesara/tensor/basic_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import logging
import sys
import time
import traceback
from collections import defaultdict
from io import StringIO
from typing import Optional
from itertools import chain
from typing import List, Optional, Tuple

import numpy as np

Expand All @@ -18,8 +18,10 @@
from aesara.graph.basic import (
Apply,
Constant,
Node,
Variable,
ancestors,
clone_replace,
equal_computations,
io_toposort,
)
Expand Down Expand Up @@ -3116,11 +3118,7 @@ def elemwise_max_input_fct(node):


class FusionOptimizer(GlobalOptimizer):
"""Graph optimizer that simply runs local fusion operations.
TODO: This is basically a `EquilibriumOptimizer`; we should just use that.
"""
"""Graph optimizer that fuses consecutive Elemwise operations."""

def __init__(self, local_optimizer):
super().__init__()
Expand All @@ -3129,38 +3127,199 @@ def __init__(self, local_optimizer):
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())

def find_fuseable_subgraphs(
self, fg: FunctionGraph
) -> List[Tuple[List[Variable], List[Variable]]]:
"""Find all subgraphs in a FunctionGraph that can be fused together
Returns
-------
List of independent subgraphs inputs and outputs
"""

def find_leaf_elemwise_vars(node: Node):
# Only consider nodes with single outputs
if len(node.outputs) != 1:
return []

if isinstance(node.op, Elemwise):
return [node.outputs[0]]

upstream_root_elemwise_vars = (
find_leaf_elemwise_vars(inp.owner)
for inp in node.inputs
if inp.owner is not None
)
# Flatten root variables
return list(chain.from_iterable(upstream_root_elemwise_vars))

def find_root_consecutive_elemwise_vars(node: Node):
# TODO: Do not walk across broadcastad elemwises
# TODO: Add special C-code check
root_elemwise_vars = []
for inp in node.inputs:
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and len(inp.owner.outputs) == 1
# Do not merge Elemwise Ops that don't have the same
# broadcastable pattern to avoid duplicated computations
and inp.type.broadcastable == node.outputs[0].type.broadcastable
):
# Try further upstream
root_elemwise_vars.extend(
find_root_consecutive_elemwise_vars(inp.owner)
)
else:
root_elemwise_vars.append(inp)
return root_elemwise_vars

# aesara.dprint(fg)
elemwise_outputs = set()
for out in fg.outputs:
if out.owner is not None:
elemwise_outputs.update(find_leaf_elemwise_vars(out.owner))

# print(f"{elemwise_outputs=}")
if not elemwise_outputs:
return []

elemwise_inputs: dict = {
out: find_root_consecutive_elemwise_vars(out.owner)
for out in elemwise_outputs
}
# print(f"{elemwise_inputs=}")

# Filter out isolated elemwise nodes
elemwise_outputs = [
out for out in elemwise_outputs if elemwise_inputs[out] != out.owner.inputs
]
# print(f"{elemwise_outputs=}")
if not elemwise_outputs:
return []

# Separate subgraphs that share no inputs whatsoever
disjoint_elemwise_outputs = [[elemwise_outputs.pop(0)]]
for next_out in elemwise_outputs:
disjoint = True
for prev_outs in disjoint_elemwise_outputs:
for prev_out in prev_outs:
if any(
set(elemwise_inputs[next_out]) & set(elemwise_inputs[prev_out])
):
prev_outs.append(next_out)
disjoint = False
break
if not disjoint:
break
if disjoint:
disjoint_elemwise_outputs.append([next_out])
# print(f"{disjoint_elemwise_outputs=}")

disjoint_elemwise_inputs = []
for outs in disjoint_elemwise_outputs:
inps = []
for out in outs:
for inp in elemwise_inputs[out]:
if inp not in inps:
inps.append(inp)
disjoint_elemwise_inputs.append(inps)
# print(f"{disjoint_elemwise_inputs=}")

fuseable_subgraphs = [
(inps, outs)
for inps, outs in zip(disjoint_elemwise_inputs, disjoint_elemwise_outputs)
]

# Call function in the inputs
inputs = []
for inps in elemwise_inputs.values():
for inp in inps:
if inp not in inputs:
inputs.append(inp)
# print(f"{inputs=}")
# print(" ")
upstream_fg = FunctionGraph(outputs=inputs, clone=False)
fuseable_subgraphs.extend(self.find_fuseable_subgraphs(upstream_fg))

return fuseable_subgraphs

def elemwise_to_scalar(self, inputs, outputs):
replace_inputs = [(inp, inp.type()) for inp in inputs]
outputs = clone_replace(outputs, replace=replace_inputs)

inputs = [inp for _, inp in replace_inputs]
fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
middle_inputs = []

scalar_inputs = [
aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
]
middle_scalar_inputs = []

# print(f"{fg.toposort()=}")
for node in fg.toposort():
node_scalar_inputs = []
for inp in node.inputs:
if inp in inputs:
node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
elif inp in middle_inputs:
node_scalar_inputs.append(
middle_scalar_inputs[middle_inputs.index(inp)]
)
else:
new_scalar_input = aes.get_scalar_type(
inp.type.dtype
).make_variable()
node_scalar_inputs.append(new_scalar_input)
middle_scalar_inputs.append(new_scalar_input)
middle_inputs.append(inp)

new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
middle_scalar_inputs.append(new_scalar_node.outputs[0])
middle_inputs.append(node.outputs[0])

scalar_outputs = [
middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
]
return scalar_inputs, scalar_outputs

def apply(self, fgraph):
did_something = True
nb_iter = 0
nb_replacement = 0
nb_inconsistency_replace = 0
time_toposort = 0

if fgraph.profile:
validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
while did_something:
t0 = time.time()
nodelist = list(fgraph.toposort())
time_toposort += time.time() - t0
nodelist.reverse()
did_something = False
for node in nodelist:
# Don't try to fuse node that have already been fused.
if node in fgraph.apply_nodes:
new_outputs = self.optimizer(fgraph, node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
try:
fgraph.replace_all_validate(
list(zip(node.outputs, new_outputs)),
reason=self.__class__.__name__,
)
did_something = True
nb_replacement += 1
except InconsistencyError:
nb_inconsistency_replace += 1
nb_iter += 1

max_inputs = elemwise_max_input_fct(None)
for inputs, outputs in self.find_fuseable_subgraphs(fgraph):
# TODO: If we care about Python mode, we should try to fuse the
# largest possible subgraphs based on number of inputs, instead
# of just failing like we used to do before
if len(inputs) > max_inputs:
_logger.warning(
"Loop fusion failed because the resulting node would exceed "
"the kernel argument limit."
)
continue
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))(
*inputs
)
if not isinstance(composite_outputs, list):
composite_outputs = [composite_outputs]

try:
# print(f"{outputs=}, {composite_outputs=}")
fgraph.replace_all_validate(
list(zip(outputs, composite_outputs)),
reason=self.__class__.__name__,
)
nb_replacement += 1
except InconsistencyError:
nb_inconsistency_replace += 1

if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
Expand All @@ -3175,19 +3334,21 @@ def apply(self, fgraph):
validate_time = None
callback_time = None
callbacks_time = {}

return (
self,
nb_iter,
1, # nb_iter
nb_replacement,
nb_inconsistency_replace,
validate_time,
callback_time,
callbacks_time,
time_toposort,
0, # toposort_time
)

@staticmethod
def print_profile(stream, prof, level=0):
# TODO: Update this
blanc = " " * level
print(blanc, "FusionOptimizer", file=stream)
print(blanc, " nb_iter", prof[1], file=stream)
Expand Down
44 changes: 23 additions & 21 deletions tests/tensor/test_basic_opt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import copy

import numpy as np
Expand Down Expand Up @@ -1157,11 +1156,8 @@ def impl(self, x):

@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
"""Make sure that `local_elemwise_fusion_op` uses test values correctly
when they have zero dimensions.
"""

opts = OptimizationQuery(
Expand All @@ -1181,27 +1177,33 @@ def test_test_values(self, test_value):
y.tag.test_value = test_value
z.tag.test_value = test_value

if test_value.size == 0:
cm = pytest.raises(ValueError)
else:
cm = contextlib.suppress()

with config.change_flags(
compute_test_value="raise", compute_test_value_opt="raise"
):
out = x * y + z
with cm:
f = function([x, y, z], out, mode=mode)
f = function([x, y, z], out, mode=mode)

if test_value.size != 0:
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1

x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
)
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value,
np.full_like(test_value, 2.0),
)

def test_multiple_outputs(self):
x = vector("x")
y = exp(x / 4)
w = y * 2
z = y + 2

f = aesara.function([x], [w, z])
aesara.dprint(f)
assert len(f.maker.fgraph.apply_nodes) == 1
r = f([0, 0])
assert np.allclose(r[0], [2, 2])
assert np.allclose(r[1], [3, 3])


class TimesN(aes.basic.UnaryScalarOp):
Expand Down

0 comments on commit 074f125

Please sign in to comment.