diff --git a/tests/test_utils.py b/tests/test_utils.py index 92367ea..d1bdf0d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,9 @@ -from copy import copy -from types import MethodType - import aesara.tensor as at import numpy as np import pytest +from aesara.graph.basic import Apply from aesara.tensor.exceptions import ShapeError +from aesara.tensor.random.basic import NormalRV from aehmc.utils import RaveledParamsMap @@ -81,12 +80,25 @@ def test_RaveledParamsMap_dtype(): def test_RaveledParamsMap_bad_infer_shape(): - bad_normal_op = copy(at.random.normal) - - def bad_infer_shape(self, *args, **kwargs): - raise ShapeError() - - bad_normal_op.infer_shape = MethodType(bad_infer_shape, bad_normal_op) + class BadNormalRV(NormalRV): + def make_node(self, *args, **kwargs): + res = super().make_node(*args, **kwargs) + # Drop static `Type`-level shape information + rv_out = res.outputs[1] + outputs = [ + res.outputs[0].clone(), + at.tensor(dtype=rv_out.type.dtype, shape=(None,) * rv_out.type.ndim), + ] + return Apply( + self, + res.inputs, + outputs, + ) + + def infer_shape(self, *args, **kwargs): + raise ShapeError() + + bad_normal_op = BadNormalRV() size = (3, 2) beta_rv = bad_normal_op(0, 1, size=size, name="beta")