From eff0721baeab64f38e42749c0713aab0aaab4e4a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 May 2023 22:41:41 +0200 Subject: [PATCH] Implement Blockwise Op to vectorize existing Ops Inspired by: https://github.com/aesara-devs/aesara/pull/1215 Co-authored-by: Brandon T. Willard Co-authored-by: Purna Chandra Mansingh Co-authored-by: Sayam Kumar Co-authored-by: Kaustubh --- pytensor/tensor/blockwise.py | 413 +++++++++++++++++++++++ pytensor/tensor/elemwise.py | 73 ++-- pytensor/tensor/random/op.py | 29 +- pytensor/tensor/rewriting/__init__.py | 1 + pytensor/tensor/rewriting/blockwise.py | 41 +++ pytensor/tensor/utils.py | 53 +++ tests/tensor/random/test_op.py | 36 ++ tests/tensor/rewriting/test_blockwise.py | 36 ++ tests/tensor/test_blockwise.py | 258 ++++++++++++++ tests/tensor/test_elemwise.py | 70 +++- 10 files changed, 964 insertions(+), 46 deletions(-) create mode 100644 pytensor/tensor/blockwise.py create mode 100644 pytensor/tensor/rewriting/blockwise.py create mode 100644 tests/tensor/rewriting/test_blockwise.py create mode 100644 tests/tensor/test_blockwise.py diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py new file mode 100644 index 0000000000..50f7445ce2 --- /dev/null +++ b/pytensor/tensor/blockwise.py @@ -0,0 +1,413 @@ +import re +from functools import singledispatch +from typing import Any, Dict, List, Optional, Sequence, Tuple, cast + +import numpy as np + +from pytensor import config +from pytensor.gradient import DisconnectedType +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.null_type import NullType +from pytensor.graph.op import Op +from pytensor.tensor import as_tensor_variable +from pytensor.tensor.shape import shape_padleft +from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor +from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string +from pytensor.tensor.var import TensorVariable + + +# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad) + +# Copied verbatim from numpy.lib.function_base +# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029 +_DIMENSION_NAME = r"\w+" +_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def _parse_gufunc_signature(signature): + """ + Parse string signatures for a generalized universal function. + + Arguments + --------- + signature : string + Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)`` + for ``np.matmul``. + + Returns + ------- + Tuple of input and output core dimensions parsed from the signature, each + of the form List[Tuple[str, ...]]. + """ + signature = re.sub(r"\s+", "", signature) + + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + return tuple( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) + + +def safe_signature( + core_inputs: Sequence[Variable], + core_outputs: Sequence[Variable], +) -> str: + def operand_sig(operand: Variable, prefix: str) -> str: + operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim)) + return f"({operands})" + + inputs_sig = ",".join( + operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs) + ) + outputs_sig = ",".join( + operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs) + ) + return f"{inputs_sig}->{outputs_sig}" + + +@singledispatch +def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply: + if hasattr(op, "gufunc_signature"): + signature = op.gufunc_signature + else: + # TODO: This is pretty bad for shape inference and merge optimization! + # Should get better as we add signatures to our Ops + signature = safe_signature(node.inputs, node.outputs) + return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs)) + + +def vectorize_node(node: Apply, *batched_inputs) -> Apply: + """Returns vectorized version of node with new batched inputs.""" + op = node.op + return _vectorize_node(op, node, *batched_inputs) + + +class Blockwise(Op): + """Generalizes a core `Op` to work with batched dimensions. + + TODO: Dispatch JAX (should be easy with the vectorize macro) + TODO: Dispatch Numba + TODO: C implementation? + TODO: Fuse Blockwise? + """ + + __props__ = ("core_op", "signature") + + def __init__( + self, + core_op: Op, + signature: Optional[str] = None, + name: Optional[str] = None, + **kwargs, + ): + """ + + Parameters + ---------- + core_op + An instance of a subclass of `Op` which works on the core case. + signature + Generalized universal function signature, + e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication + + """ + if isinstance(core_op, Blockwise): + raise TypeError("Core Op is already a Blockwise") + + if signature is None: + signature = getattr(core_op, "gufunc_signature", None) + if signature is None: + raise ValueError( + f"Signature not provided nor found in core_op {core_op}" + ) + + self.core_op = core_op + self.signature = signature + self.name = name + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self._gufunc = None + super().__init__(**kwargs) + + def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: + core_input_types = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if inp.type.ndim < len(sig): + raise ValueError( + f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" + ) + # ndim_supp = 0 case + if not sig: + core_shape = () + else: + core_shape = inp.type.shape[-len(sig) :] + core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape)) + + core_node = self.core_op.make_node(*core_input_types) + + if len(core_node.outputs) != len(self.outputs_sig): + raise ValueError( + f"Insufficient number of outputs for signature {self.signature}: {len(core_node.outputs)}" + ) + for i, (core_out, sig) in enumerate(zip(core_node.outputs, self.outputs_sig)): + if core_out.type.ndim != len(sig): + raise ValueError( + f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" + ) + + return core_node + + def make_node(self, *inputs): + inputs = [as_tensor_variable(i) for i in inputs] + + core_node = self._create_dummy_core_node(inputs) + + batch_ndims = max( + inp.type.ndim - len(sig) for inp, sig in zip(inputs, self.inputs_sig) + ) + + batched_inputs = [] + batch_shapes = [] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + # Append missing dims to the left + missing_batch_ndims = batch_ndims - (inp.type.ndim - len(sig)) + if missing_batch_ndims: + inp = shape_padleft(inp, missing_batch_ndims) + batched_inputs.append(inp) + + if not sig: + batch_shapes.append(inp.type.shape) + else: + batch_shapes.append(inp.type.shape[: -len(sig)]) + + try: + batch_shape = tuple( + [ + broadcast_static_dim_lengths(batch_dims) + for batch_dims in zip(*batch_shapes) + ] + ) + except ValueError: + raise ValueError( + f"Incompatible Blockwise batch input shapes {[inp.type.shape for inp in inputs]}" + ) + + batched_outputs = [ + tensor(dtype=core_out.type.dtype, shape=batch_shape + core_out.type.shape) + for core_out in core_node.outputs + ] + + return Apply(self, batched_inputs, batched_outputs) + + def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int: + return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0])) + + def infer_shape( + self, fgraph, node, input_shapes + ) -> List[Tuple[TensorVariable, ...]]: + from pytensor.tensor import broadcast_shape + from pytensor.tensor.shape import Shape_i + + batch_ndims = self._batch_ndim_from_outputs(node.outputs) + core_dims: Dict[str, Any] = {} + batch_shapes = [] + for input_shape, sig in zip(input_shapes, self.inputs_sig): + batch_shapes.append(input_shape[:batch_ndims]) + core_shape = input_shape[batch_ndims:] + + for core_dim, dim_name in zip(core_shape, sig): + prev_core_dim = core_dims.get(core_dim) + if prev_core_dim is None: + core_dims[dim_name] = core_dim + # Prefer constants + elif not isinstance(prev_core_dim, Constant): + core_dims[dim_name] = core_dim + + batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) + + out_shapes = [] + for output, sig in zip(node.outputs, self.outputs_sig): + core_out_shape = [] + for i, dim_name in enumerate(sig): + # The output dim is the same as another input dim + if dim_name in core_dims: + core_out_shape.append(core_dims[dim_name]) + else: + # TODO: We could try to make use of infer_shape of core_op + core_out_shape.append(Shape_i(batch_ndims + i)(output)) + out_shapes.append((*batch_shape, *core_out_shape)) + + return out_shapes + + def connection_pattern(self, node): + if hasattr(self.core_op, "connection_pattern"): + return self.core_op.connection_pattern(node) + + return [[True for _ in node.outputs] for _ in node.inputs] + + def _bgrad(self, inputs, outputs, ograds): + # Grad, with respect to broadcasted versions of inputs + + def as_core(t, core_t): + # Inputs could be NullType or DisconnectedType + if isinstance(t.type, (NullType, DisconnectedType)): + return t + return core_t.type() + + with config.change_flags(compute_test_value="off"): + safe_inputs = [ + tensor(dtype=inp.type.dtype, shape=(None,) * len(sig)) + for inp, sig in zip(inputs, self.inputs_sig) + ] + core_node = self._create_dummy_core_node(safe_inputs) + + core_inputs = [ + as_core(inp, core_inp) + for inp, core_inp in zip(inputs, core_node.inputs) + ] + core_ograds = [ + as_core(ograd, core_ograd) + for ograd, core_ograd in zip(ograds, core_node.outputs) + ] + core_outputs = core_node.outputs + + core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) + + batch_ndims = self._batch_ndim_from_outputs(outputs) + + def transform(var): + # From a graph of ScalarOps, make a graph of Broadcast ops. + if isinstance(var.type, (NullType, DisconnectedType)): + return var + if var in core_inputs: + return inputs[core_inputs.index(var)] + if var in core_outputs: + return outputs[core_outputs.index(var)] + if var in core_ograds: + return ograds[core_ograds.index(var)] + + node = var.owner + + # The gradient contains a constant, which may be responsible for broadcasting + if node is None: + if batch_ndims: + var = shape_padleft(var, batch_ndims) + return var + + batched_inputs = [transform(inp) for inp in node.inputs] + batched_node = vectorize_node(node, *batched_inputs) + batched_var = batched_node.outputs[var.owner.outputs.index(var)] + + return batched_var + + ret = [] + for core_igrad, ipt in zip(core_igrads, inputs): + # Undefined gradient + if core_igrad is None: + ret.append(None) + else: + ret.append(transform(core_igrad)) + + return ret + + def L_op(self, inputs, outs, ograds): + from pytensor.tensor.math import sum as pt_sum + + # Compute grad with respect to broadcasted input + rval = self._bgrad(inputs, outs, ograds) + + # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable + # to the gradient.grad method when the outputs have + # some integer and some floating point outputs + if any(out.type.dtype not in continuous_dtypes for out in outs): + # For integer output, return value may only be zero or undefined + # We don't bother with trying to check that the scalar ops + # correctly returned something that evaluates to 0, we just make + # the return value obviously zero so that gradient.grad can tell + # this op did the right thing. + new_rval = [] + for elem, inp in zip(rval, inputs): + if isinstance(elem.type, (NullType, DisconnectedType)): + new_rval.append(elem) + else: + elem = inp.zeros_like() + if str(elem.type.dtype) not in continuous_dtypes: + elem = elem.astype(config.floatX) + assert str(elem.type.dtype) not in discrete_dtypes + new_rval.append(elem) + return new_rval + + # Sum out the broadcasted dimensions + batch_ndims = self._batch_ndim_from_outputs(outs) + batch_shape = outs[0].type.shape[:batch_ndims] + for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): + if isinstance(rval[i].type, (NullType, DisconnectedType)): + continue + + assert inp.type.ndim == batch_ndims + len(sig) + + to_sum = [ + j + for j, (inp_s, out_s) in enumerate(zip(inp.type.shape, batch_shape)) + if inp_s == 1 and out_s != 1 + ] + if to_sum: + rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True) + + return rval + + def _create_gufunc(self, node): + if hasattr(self.core_op, "gufunc_spec"): + self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0]) + if self._gufunc: + return self._gufunc + + n_outs = len(self.outputs_sig) + core_node = self._create_dummy_core_node(node.inputs) + + def core_func(*inner_inputs): + inner_outputs = [[None] for _ in range(n_outs)] + + inner_inputs = [np.asarray(inp) for inp in inner_inputs] + self.core_op.perform(core_node, inner_inputs, inner_outputs) + + if len(inner_outputs) == 1: + return inner_outputs[0][0] + else: + return tuple(r[0] for r in inner_outputs) + + self._gufunc = np.vectorize(core_func, signature=self.signature) + return self._gufunc + + def perform(self, node, inputs, output_storage): + gufunc = self._gufunc + + if gufunc is None: + gufunc = self._create_gufunc(node) + + res = gufunc(*inputs) + if not isinstance(res, tuple): + res = (res,) + + for node_out, out_storage, r in zip(node.outputs, output_storage, res): + out_dtype = getattr(node_out, "dtype", None) + if out_dtype and out_dtype != r.dtype: + r = np.asarray(r, dtype=out_dtype) + out_storage[0] = r + + def __str__(self): + if self.name is None: + return f"{type(self).__name__}{{{self.core_op}, {self.signature}}}" + else: + return self.name + + +@_vectorize_node.register(Blockwise) +def vectorize_not_needed(op, node, *batch_inputs): + return op.make_node(*batch_inputs) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index b510a2d3b9..377444609d 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -22,6 +22,7 @@ from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable +from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed from pytensor.tensor.type import ( TensorType, continuous_dtypes, @@ -29,6 +30,7 @@ float_dtypes, lvector, ) +from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string from pytensor.tensor.var import TensorVariable from pytensor.utils import uniq @@ -232,7 +234,7 @@ def __str__(self): return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" - def perform(self, node, inp, out, params): + def perform(self, node, inp, out, params=None): (res,) = inp (storage,) = out @@ -429,28 +431,12 @@ def get_output_info(self, dim_shuffle, *inputs): # of all inputs in parallel... the all() gives us each output # broadcastable bit in turn. - def get_most_specialized_shape(shapes): - shapes = set(shapes) - # All shapes are the same - if len(shapes) == 1: - return tuple(shapes)[0] - - # Only valid indeterminate case - if shapes == {None, 1}: - return None - - shapes.discard(1) - shapes.discard(None) - if len(shapes) > 1: - raise ValueError - return tuple(shapes)[0] - # it is multiplied by nout because Elemwise supports multiple outputs # (nout of them) try: out_shapes = [ [ - get_most_specialized_shape(shape) + broadcast_static_dim_lengths(shape) for shape in zip(*[inp.type.shape for inp in inputs]) ] ] * shadow.nout @@ -665,22 +651,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): impl = "c" if getattr(self, "nfunc_spec", None) and impl != "c": - self.nfunc = getattr(np, self.nfunc_spec[0], None) - if self.nfunc is None: - # Not inside NumPy. So probably another package like scipy. - symb = self.nfunc_spec[0].split(".") - for idx in range(1, len(self.nfunc_spec[0])): - try: - module = __import__(".".join(symb[:idx])) - except ImportError: - break - for sub in symb[1:]: - try: - module = getattr(module, sub) - except AttributeError: - module = None - break - self.nfunc = module + self.nfunc = import_func_from_string(self.nfunc_spec[0]) if ( (len(node.inputs) + len(node.outputs)) <= 32 @@ -1768,3 +1739,37 @@ def _get_vector_length_Elemwise(op, var): return get_vector_length(var.owner.inputs[0]) raise ValueError(f"Length of {var} cannot be determined") + + +_vectorize_node.register(Elemwise, vectorize_not_needed) + + +@_vectorize_node.register(DimShuffle) +def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node.op.make_node(x) + input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable + # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2)) + # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x")) + new_order = list(range(batched_ndims)) + [ + "x" if (o == "x") else (o + batched_ndims) for o in op.new_order + ] + return DimShuffle(input_broadcastable, new_order).make_node(x) + + +@_vectorize_node.register(CAReduce) +def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply: + batched_ndims = x.type.ndim - node.inputs[0].type.ndim + if not batched_ndims: + return node.op.make_node(x) + axes = op.axis + # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3)) + # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,)) + if axes is None: + axes = list(range(node.inputs[0].type.ndim)) + else: + axes = list(axes) + new_axes = [axis + batched_ndims for axis in axes] + new_op = op.clone(axis=new_axes) + return new_op.make_node(x) diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index a8c47d5ee8..9461bf440a 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -5,19 +5,25 @@ import pytensor from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Variable +from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.op import Op from pytensor.misc.safe_asarray import _asarray from pytensor.scalar import ScalarVariable from pytensor.tensor.basic import ( as_tensor_variable, + concatenate, constant, get_underlying_scalar_constant_value, get_vector_length, infer_static_shape, ) +from pytensor.tensor.blockwise import _vectorize_node from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType -from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes +from pytensor.tensor.random.utils import ( + broadcast_params, + normalize_size_param, + params_broadcast_shapes, +) from pytensor.tensor.shape import shape_tuple from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type_other import NoneConst @@ -383,3 +389,22 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor): default_rng = DefaultGeneratorMakerOp() + + +@_vectorize_node.register(RandomVariable) +def vectorize_random_variable( + op: RandomVariable, node: Apply, rng, size, dtype, *dist_params +) -> Apply: + # If size was provided originally and a new size hasn't been provided, + # We extend it to accommodate the new input batch dimensions. + # Otherwise, we assume the new size already has the right values + old_size = node.inputs[1] + len_old_size = get_vector_length(old_size) + if len_old_size and equal_computations([old_size], [size]): + bcasted_param = broadcast_params(dist_params, op.ndims_params)[0] + new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size + if new_param_ndim >= 0: + new_size_dims = bcasted_param.shape[:new_param_ndim] + size = concatenate([new_size_dims, size]) + + return op.make_node(rng, size, dtype, *dist_params) diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 80946d524c..617eab04fa 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -2,6 +2,7 @@ import pytensor.tensor.rewriting.blas import pytensor.tensor.rewriting.blas_c import pytensor.tensor.rewriting.blas_scipy +import pytensor.tensor.rewriting.blockwise import pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py new file mode 100644 index 0000000000..c85fba3815 --- /dev/null +++ b/pytensor/tensor/rewriting/blockwise.py @@ -0,0 +1,41 @@ +from pytensor.compile.mode import optdb +from pytensor.graph import node_rewriter +from pytensor.graph.rewriting.basic import copy_stack_trace, out2in +from pytensor.tensor.blockwise import Blockwise, vectorize_node + + +@node_rewriter([Blockwise]) +def local_useless_blockwise(fgraph, node): + """ + If there is a dispatch implementation that does not require Blockwise, use that instead. + This means a user created a Blockwise manually when there was no need. + + Note: This rewrite is not registered by default anywhere + """ + op = node.op + inputs = node.inputs + dummy_core_node = op._create_dummy_core_node(node.inputs) + vect_node = vectorize_node(dummy_core_node, *inputs) + if not isinstance(vect_node.op, Blockwise): + return copy_stack_trace(node.outputs, vect_node.outputs) + + +@node_rewriter([Blockwise]) +def local_useless_unbatched_blockwise(fgraph, node): + """Remove Blockwise that don't have any batched dims.""" + op = node.op + inputs = node.inputs + + if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0: + return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs) + + +# We register this rewrite late, so that other rewrites need only target Blockwise Ops +optdb.register( + "local_useless_unbatched_blockwise", + out2in(local_useless_unbatched_blockwise, ignore_newtrees=True), + "fast_run", + "fast_compile", + "blockwise", + position=49, +) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 7535f47c5c..2150587180 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -1,3 +1,5 @@ +from typing import Sequence, Union + import numpy as np import pytensor @@ -107,3 +109,54 @@ def as_list(x): return list(x) except TypeError: return [x] + + +def import_func_from_string(func_string: str): # -> Optional[Callable]: + func = getattr(np, func_string, None) + if func is not None: + return func + + # Not inside NumPy or Scipy. So probably another package like scipy. + module = None + items = func_string.split(".") + for idx in range(1, len(items)): + try: + module = __import__(".".join(items[:idx])) + except ImportError: + break + + if module: + for sub in items[1:]: + try: + module = getattr(module, sub) + except AttributeError: + module = None + break + return module + + +def broadcast_static_dim_lengths( + dim_lengths: Sequence[Union[int, None]] +) -> Union[int, None]: + """Apply static broadcast given static dim length of inputs (obtained from var.type.shape). + + Raises + ------ + ValueError + When static dim lengths are incompatible + """ + + dim_lengths_set = set(dim_lengths) + # All dim_lengths are the same + if len(dim_lengths_set) == 1: + return tuple(dim_lengths_set)[0] + + # Only valid indeterminate case + if dim_lengths_set == {None, 1}: + return None + + dim_lengths_set.discard(1) + dim_lengths_set.discard(None) + if len(dim_lengths_set) > 1: + raise ValueError + return tuple(dim_lengths_set)[0] diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py index 0eec50e5a6..0bc8f0a73f 100644 --- a/tests/tensor/random/test_op.py +++ b/tests/tensor/random/test_op.py @@ -5,7 +5,9 @@ from pytensor import config, function from pytensor.gradient import NullTypeGradError, grad from pytensor.raise_op import Assert +from pytensor.tensor.blockwise import vectorize_node from pytensor.tensor.math import eq +from pytensor.tensor.random import normal from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.shape import specify_shape from pytensor.tensor.type import all_dtypes, iscalar, tensor @@ -202,3 +204,37 @@ def test_RandomVariable_incompatible_size(): ValueError, match="Size length is incompatible with batched dimensions" ): rv_op(np.zeros((2, 4, 3)), 1, size=(4,)) + + +def test_vectorize_node(): + vec = tensor(shape=(None,)) + vec.tag.test_value = [0, 0, 0] + mat = tensor(shape=(None, None)) + mat.tag.test_value = [[0, 0, 0], [1, 1, 1]] + + # Test without size + node = normal(vec).owner + new_inputs = node.inputs.copy() + new_inputs[3] = mat + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert vect_node.inputs[3] is mat + + # Test with size, new size provided + node = normal(vec, size=(3,)).owner + new_inputs = node.inputs.copy() + new_inputs[1] = (2, 3) + new_inputs[3] = mat + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert tuple(vect_node.inputs[1].eval()) == (2, 3) + assert vect_node.inputs[3] is mat + + # Test with size, new size not provided + node = normal(vec, size=(3,)).owner + new_inputs = node.inputs.copy() + new_inputs[3] = mat + vect_node = vectorize_node(node, *new_inputs) + assert vect_node.op is normal + assert vect_node.inputs[3] is mat + assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3) diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py new file mode 100644 index 0000000000..6a01e4ed2a --- /dev/null +++ b/tests/tensor/rewriting/test_blockwise.py @@ -0,0 +1,36 @@ +from pytensor import function +from pytensor.scalar import log as scalar_log +from pytensor.tensor import matrix, tensor3 +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.nlinalg import MatrixPinv + + +def test_useless_blockwise_of_elemwise(): + x = matrix("x") + out = Blockwise(Elemwise(scalar_log), signature="()->()")(x) + + assert isinstance(out.owner.op, Blockwise) + assert isinstance(out.owner.op.core_op, Elemwise) + + fn = function([x], out, mode="FAST_COMPILE") + assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Elemwise) + + +def test_useless_unbatched_blockwise(): + x = matrix("x") + blockwise_op = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)") + out = blockwise_op(x) + + assert isinstance(out.owner.op, Blockwise) + assert isinstance(out.owner.op.core_op, MatrixPinv) + + fn = function([x], out, mode="FAST_COMPILE") + assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv) + + # Test that it's not removed when there are batched dims + x = tensor3("x") + out = blockwise_op(x) + fn = function([x], out, mode="FAST_COMPILE") + assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) + assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py new file mode 100644 index 0000000000..437e3bbc22 --- /dev/null +++ b/tests/tensor/test_blockwise.py @@ -0,0 +1,258 @@ +from itertools import product +from typing import Optional, Tuple, Union + +import numpy as np +import pytest + +import pytensor +from pytensor import config +from pytensor.gradient import grad +from pytensor.graph import Apply, Op +from pytensor.tensor import tensor +from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node +from pytensor.tensor.nlinalg import MatrixInverse +from pytensor.tensor.slinalg import Cholesky, Solve + + +def test_vectorize_blockwise(): + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + # Something that falls back to Blockwise + node = MatrixInverse()(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, MatrixInverse + ) + assert vect_node.inputs[0] is tns + + # Useless blockwise + tns4 = tensor(shape=(5, None, None, None)) + new_vect_node = vectorize_node(vect_node, tns4) + assert new_vect_node.op is vect_node.op + assert isinstance(new_vect_node.op, Blockwise) and isinstance( + new_vect_node.op.core_op, MatrixInverse + ) + assert new_vect_node.inputs[0] is tns4 + + +class TestOp(Op): + def make_node(self, *inputs): + return Apply(self, inputs, [i.type() for i in inputs]) + + def perform(self, *args, **kwargs): + raise NotImplementedError("Test Op should not be present in final graph") + + +test_op = TestOp() + + +def test_vectorize_node_default_signature(): + vec = tensor(shape=(None,)) + mat = tensor(shape=(5, None)) + node = test_op.make_node(vec, mat) + + vect_node = vectorize_node(node, mat, mat) + assert isinstance(vect_node.op, Blockwise) and isinstance( + vect_node.op.core_op, TestOp + ) + assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)") + + with pytest.raises( + ValueError, match="Signature not provided nor found in core_op TestOp" + ): + Blockwise(test_op) + + vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + assert vect_node.outputs[0].type.shape == ( + 5, + None, + ) + + +def test_blockwise_shape(): + # Single output + inp = tensor(shape=(5, None, None)) + inp_test = np.zeros((5, 4, 3), dtype=config.floatX) + + # Shape can be inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (n, m)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 3, 4) + + # Shape can only be partially inferred from inputs + op = Blockwise(test_op, signature="(m, n) -> (m, k)") + out = op(inp) + assert out.type.shape == (5, None, None) + + shape_fn = pytensor.function([inp], out.shape) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp], out.shape[:-1]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp_test)) == (5, 4) + + # Mutiple outputs + inp1 = tensor(shape=(7, 1, None, None)) + inp2 = tensor(shape=(1, 5, None, None)) + inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX) + inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX) + + op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)") + outs = op(inp1, inp2) + assert outs[0].type.shape == (7, 5, None, None) + assert outs[1].type.shape == (7, 5, None, None) + + shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs]) + assert any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + + shape_fn = pytensor.function([inp1, inp2], outs[0].shape) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4) + + shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]]) + assert not any( + isinstance(getattr(n.op, "core_op", n.op), TestOp) + for n in shape_fn.maker.fgraph.apply_nodes + ) + assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4) + assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4) + + +class BlockwiseOpTester: + """Base class to test Blockwise works for specific Ops""" + + core_op = None + signature = None + batcheable_axes = None + + @classmethod + def setup_class(cls): + seed = sum(map(ord, str(cls.core_op))) + cls.rng = np.random.default_rng(seed) + cls.params_sig, cls.outputs_sig = _parse_gufunc_signature(cls.signature) + if cls.batcheable_axes is None: + cls.batcheable_axes = list(range(len(cls.params_sig))) + batch_shapes = [(), (1,), (5,), (1, 1), (1, 5), (3, 1), (3, 5)] + cls.test_batch_shapes = list( + product(batch_shapes, repeat=len(cls.batcheable_axes)) + ) + cls.block_op = Blockwise(core_op=cls.core_op, signature=cls.signature) + + @staticmethod + def parse_shape(shape: Tuple[Union[str, int], ...]) -> Tuple[int, ...]: + """ + Convert (5, "m", "n") -> (5, 7, 11) + """ + mapping = {"m": 7, "n": 11, "k": 19} + return tuple(mapping.get(p, p) for p in shape) + + def create_testvals(self, shape): + return self.rng.normal(size=self.parse_shape(shape)).astype(config.floatX) + + def create_batched_inputs(self, batch_idx: Optional[int] = None): + for batch_shapes in self.test_batch_shapes: + vec_inputs = [] + vec_inputs_testvals = [] + for idx, (batch_shape, param_sig) in enumerate( + zip(batch_shapes, self.params_sig) + ): + if batch_idx is not None and idx != batch_idx: + # Skip out combinations in which other inputs are batched + if batch_shape != (): + break + vec_inputs.append(tensor(shape=batch_shape + (None,) * len(param_sig))) + vec_inputs_testvals.append( + self.create_testvals(shape=batch_shape + param_sig) + ) + else: # no-break + yield vec_inputs, vec_inputs_testvals + + def test_perform(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + core_func = pytensor.function(base_inputs, self.core_op(*base_inputs)) + np_func = np.vectorize(core_func, signature=self.signature) + + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs(): + pt_func = pytensor.function(vec_inputs, self.block_op(*vec_inputs)) + if len(self.outputs_sig) != 1: + raise NotImplementedError("Did not implement test for multi-output Ops") + np.testing.assert_allclose( + pt_func(*vec_inputs_testvals), + np_func(*vec_inputs_testvals), + ) + + def test_grad(self): + base_inputs = [ + tensor(shape=(None,) * len(param_sig)) for param_sig in self.params_sig + ] + out = self.core_op(*base_inputs).sum() + # Create separate numpy vectorized functions for each input + np_funcs = [] + for i, inp in enumerate(base_inputs): + core_grad_func = pytensor.function(base_inputs, grad(out, wrt=inp)) + params_sig = self.signature.split("->")[0] + param_sig = f"({','.join(self.params_sig[i])})" + grad_sig = f"{params_sig}->{param_sig}" + np_func = np.vectorize(core_grad_func, signature=grad_sig) + np_funcs.append(np_func) + + # We test gradient wrt to one batched input at a time + for test_input_idx in range(len(base_inputs)): + for vec_inputs, vec_inputs_testvals in self.create_batched_inputs( + batch_idx=test_input_idx + ): + out = self.block_op(*vec_inputs).sum() + pt_func = pytensor.function( + vec_inputs, grad(out, wrt=vec_inputs[test_input_idx]) + ) + pt_out = pt_func(*vec_inputs_testvals) + np_out = np_funcs[test_input_idx](*vec_inputs_testvals) + np.testing.assert_allclose(pt_out, np_out, atol=1e-6) + + +class MatrixOpBlockwiseTester(BlockwiseOpTester): + def create_testvals(self, shape): + # Return a posdef matrix + X = super().create_testvals(shape) + return np.einsum("...ij,...kj->...ik", X, X) + + +class TestCholesky(MatrixOpBlockwiseTester): + core_op = Cholesky(lower=True) + signature = "(m, m) -> (m, m)" + + +class TestMatrixInverse(MatrixOpBlockwiseTester): + core_op = MatrixInverse() + signature = "(m, m) -> (m, m)" + + +class TestSolve(BlockwiseOpTester): + core_op = Solve(lower=True) + signature = "(m, m),(m) -> (m)" diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 52f34c7be0..3d3aa1b28d 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -17,10 +17,13 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import second +from pytensor.tensor.blockwise import vectorize_node from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import all as at_all -from pytensor.tensor.math import any as at_any +from pytensor.tensor.math import Any, Sum +from pytensor.tensor.math import all as pt_all +from pytensor.tensor.math import any as pt_any from pytensor.tensor.math import exp +from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( TensorType, bmatrix, @@ -470,12 +473,12 @@ def with_mode( axis2.append(a) assert len(axis2) == len(tosum) tosum = tuple(axis2) - if tensor_op == at_all: + if tensor_op == pt_all: for axis in sorted(tosum, reverse=True): zv = np.all(zv, axis) if len(tosum) == 0: zv = zv != 0 - elif tensor_op == at_any: + elif tensor_op == pt_any: for axis in sorted(tosum, reverse=True): zv = np.any(zv, axis) if len(tosum) == 0: @@ -553,8 +556,8 @@ def test_perform(self): self.with_mode(Mode(linker="py"), aes.mul, dtype=dtype) self.with_mode(Mode(linker="py"), aes.scalar_maximum, dtype=dtype) self.with_mode(Mode(linker="py"), aes.scalar_minimum, dtype=dtype) - self.with_mode(Mode(linker="py"), aes.and_, dtype=dtype, tensor_op=at_all) - self.with_mode(Mode(linker="py"), aes.or_, dtype=dtype, tensor_op=at_any) + self.with_mode(Mode(linker="py"), aes.and_, dtype=dtype, tensor_op=pt_all) + self.with_mode(Mode(linker="py"), aes.or_, dtype=dtype, tensor_op=pt_any) for dtype in ["int8", "uint8"]: self.with_mode(Mode(linker="py"), aes.or_, dtype=dtype) self.with_mode(Mode(linker="py"), aes.and_, dtype=dtype) @@ -575,14 +578,14 @@ def test_perform_nan(self): aes.or_, dtype=dtype, test_nan=True, - tensor_op=at_any, + tensor_op=pt_any, ) self.with_mode( Mode(linker="py"), aes.and_, dtype=dtype, test_nan=True, - tensor_op=at_all, + tensor_op=pt_all, ) @pytest.mark.skipif( @@ -606,8 +609,8 @@ def test_c(self): for dtype in ["bool", "floatX", "int8", "uint8"]: self.with_mode(Mode(linker="c"), aes.scalar_minimum, dtype=dtype) self.with_mode(Mode(linker="c"), aes.scalar_maximum, dtype=dtype) - self.with_mode(Mode(linker="c"), aes.and_, dtype=dtype, tensor_op=at_all) - self.with_mode(Mode(linker="c"), aes.or_, dtype=dtype, tensor_op=at_any) + self.with_mode(Mode(linker="c"), aes.and_, dtype=dtype, tensor_op=pt_all) + self.with_mode(Mode(linker="c"), aes.or_, dtype=dtype, tensor_op=pt_any) for dtype in ["bool", "int8", "uint8"]: self.with_mode(Mode(linker="c"), aes.or_, dtype=dtype) self.with_mode(Mode(linker="c"), aes.and_, dtype=dtype) @@ -915,3 +918,50 @@ def grad(self, inputs, gout): # Verify that trying to use the not implemented gradient fails. with pytest.raises(pytensor.gradient.NullTypeGradError): pytensor.gradient.grad(test_op(x, 2), x) + + +class TestVectorize: + def test_elemwise(self): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + node = exp(vec).owner + vect_node = vectorize_node(node, mat) + assert vect_node.op == exp + assert vect_node.inputs[0] is mat + + def test_dimshuffle(self): + vec = tensor(shape=(None,)) + mat = tensor(shape=(None, None)) + + node = exp(vec).owner + vect_node = vectorize_node(node, mat) + assert vect_node.op == exp + assert vect_node.inputs[0] is mat + + col_mat = tensor(shape=(None, 1)) + tcol_mat = tensor(shape=(None, None, 1)) + node = col_mat.dimshuffle(0).owner # drop column + vect_node = vectorize_node(node, tcol_mat) + assert isinstance(vect_node.op, DimShuffle) + assert vect_node.op.new_order == (0, 1) + assert vect_node.inputs[0] is tcol_mat + assert vect_node.outputs[0].type.shape == (None, None) + + def test_CAReduce(self): + mat = tensor(shape=(None, None)) + tns = tensor(shape=(None, None, None)) + + node = pt_sum(mat).owner + vect_node = vectorize_node(node, tns) + assert isinstance(vect_node.op, Sum) + assert vect_node.op.axis == (1, 2) + assert vect_node.inputs[0] is tns + + bool_mat = tensor(dtype="bool", shape=(None, None)) + bool_tns = tensor(dtype="bool", shape=(None, None, None)) + node = pt_any(bool_mat, axis=-2).owner + vect_node = vectorize_node(node, bool_tns) + assert isinstance(vect_node.op, Any) + assert vect_node.op.axis == (1,) + assert vect_node.inputs[0] is bool_tns