Skip to content

Commit

Permalink
Add test for nested broadcasted Composite graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo Vieira committed Oct 8, 2022
1 parent 131b541 commit e33260d
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,34 @@ def test_test_values(self, test_value):
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
)

def test_not_fusing_broadcasted_subgraphs(self):
# There are some cases in self.test_elemwise_fusion, but this test
# confirms that the fused subgraphs are exactly the expected ones
xs = scalar("xm")
xm = matrix("xs")

es = log(xs + 5)
em = exp(xm * 5)
esm = es - em

f = aesara.function([xs, xm], esm, mode=self.mode)
apply_nodes = f.maker.fgraph.toposort()
assert len(apply_nodes) == 3
assert isinstance(apply_nodes[0].op, DimShuffle)
# Inner Scalar output Composite
assert isinstance(apply_nodes[1].op.scalar_op, Composite)
assert {node.op for node in apply_nodes[1].op.scalar_op.fgraph.apply_nodes} == {
aes.add,
aes.log,
}
# Outer Matrix output Composite
assert isinstance(apply_nodes[2].op.scalar_op, Composite)
assert {node.op for node in apply_nodes[2].op.scalar_op.fgraph.apply_nodes} == {
aes.sub,
aes.exp,
aes.mul,
}


class TimesN(aes.basic.UnaryScalarOp):
"""
Expand Down

0 comments on commit e33260d

Please sign in to comment.