diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 2bf776d017..919190f9aa 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1119,6 +1119,34 @@ def test_test_values(self, test_value): f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] ) + def test_not_fusing_broadcasted_subgraphs(self): + # There are some cases in self.test_elemwise_fusion, but this test + # confirms that the fused subgraphs are exactly the expected ones + xs = scalar("xm") + xm = matrix("xs") + + es = log(xs + 5) + em = exp(xm * 5) + esm = es - em + + f = aesara.function([xs, xm], esm, mode=self.mode) + apply_nodes = f.maker.fgraph.toposort() + assert len(apply_nodes) == 3 + assert isinstance(apply_nodes[0].op, DimShuffle) + # Inner Scalar output Composite + assert isinstance(apply_nodes[1].op.scalar_op, Composite) + assert {node.op for node in apply_nodes[1].op.scalar_op.fgraph.apply_nodes} == { + aes.add, + aes.log, + } + # Outer Matrix output Composite + assert isinstance(apply_nodes[2].op.scalar_op, Composite) + assert {node.op for node in apply_nodes[2].op.scalar_op.fgraph.apply_nodes} == { + aes.sub, + aes.exp, + aes.mul, + } + class TimesN(aes.basic.UnaryScalarOp): """