Skip to content

Commit

Permalink
Implement Blockwise Op to vectorize existing Ops
Browse files Browse the repository at this point in the history
Inspired by: aesara-devs/aesara#1215

Co-authored-by: Brandon T. Willard <brandonwillard@users.noreply.github.com>
Co-authored-by: Purna Chandra Mansingh <purnachandramansingh135@gmail.com>
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com3>
Co-authored-by: Kaustubh <ckaustubhm06@gmail.com>
  • Loading branch information
5 people committed Jun 23, 2023
1 parent 117d011 commit 7256c27
Show file tree
Hide file tree
Showing 8 changed files with 872 additions and 17 deletions.
430 changes: 430 additions & 0 deletions pytensor/tensor/blockwise.py

Large diffs are not rendered by default.

55 changes: 38 additions & 17 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
discrete_dtypes,
float_dtypes,
lvector,
)
from pytensor.tensor.utils import import_func_from_string
from pytensor.tensor.var import TensorVariable
from pytensor.utils import uniq

Expand Down Expand Up @@ -228,7 +230,7 @@ def __str__(self):
return f"Transpose{{axes={self.shuffle}}}"
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"

def perform(self, node, inp, out, params):
def perform(self, node, inp, out, params=None):
(res,) = inp
(storage,) = out

Expand Down Expand Up @@ -662,22 +664,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
impl = "c"

if getattr(self, "nfunc_spec", None) and impl != "c":
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__(".".join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
try:
module = getattr(module, sub)
except AttributeError:
module = None
break
self.nfunc = module
self.nfunc = import_func_from_string(self.nfunc_spec[0])

if (
(len(node.inputs) + len(node.outputs)) <= 32
Expand Down Expand Up @@ -1759,3 +1746,37 @@ def _get_vector_length_Elemwise(op, var):
return get_vector_length(var.owner.inputs[0])

raise ValueError(f"Length of {var} cannot be determined")


_vectorize_node.register(Elemwise, vectorize_not_needed)


@_vectorize_node.register(DimShuffle)
def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
return node.op.make_node(x)
input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
# e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
# e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
new_order = list(range(batched_ndims)) + [
"x" if (o == "x") else (o + batched_ndims) for o in op.new_order
]
return DimShuffle(input_broadcastable, new_order).make_node(x)


@_vectorize_node.register(CAReduce)
def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
batched_ndims = x.type.ndim - node.inputs[0].type.ndim
if not batched_ndims:
return node.op.make_node(x)
axes = op.axis
# e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
# e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
if axes is None:
axes = list(range(node.inputs[0].type.ndim))
else:
axes = list(axes)
new_axes = [axis + batched_ndims for axis in axes]
new_op = op.clone(axis=new_axes)
return new_op.make_node(x)
6 changes: 6 additions & 0 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_vector_length,
infer_static_shape,
)
from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from pytensor.tensor.shape import shape_tuple
Expand Down Expand Up @@ -428,3 +429,8 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):


default_rng = DefaultGeneratorMakerOp()


# RandomVariables are vectorized on the parameters by default.
# RNG, size and dtype can't be vectorized, but the Op will raise if the wrong input type is passed
_vectorize_node.register(RandomVariable, vectorize_not_needed)
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops

Expand Down
39 changes: 39 additions & 0 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise, vectorize_node
from pytensor.tensor.rewriting.basic import register_useless


@register_useless("fast_compile")
@node_rewriter([Blockwise])
def local_useless_blockwise(fgraph, node):
# If there is a dispatch implementation that does not require Blockwise, use that instead.
# This means a user created a Blockwise manually when there was no need.
op = node.op
inputs = node.inputs
dummy_core_node = op._create_dummy_core_node(node.inputs)
vect_node = vectorize_node(dummy_core_node, *inputs)
if not isinstance(vect_node.op, Blockwise):
return copy_stack_trace(node.outputs, vect_node.outputs)


@node_rewriter([Blockwise])
def local_useless_unbatched_blockwise(fgraph, node):
"""Remove Blockwise that don't have any batched dims."""
op = node.op
inputs = node.inputs

if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0:
return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)


# We register this rewrite late, so that other rewrites need only target Blockwise Ops
optdb.register(
"local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run",
"fast_compile",
"blockwise",
position=49,
)
24 changes: 24 additions & 0 deletions pytensor/tensor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,27 @@ def as_list(x):
return list(x)
except TypeError:
return [x]


def import_func_from_string(func_string: str): # -> Optional[Callable]:
func = getattr(np, func_string, None)
if func is not None:
return func

# Not inside NumPy or Scipy. So probably another package like scipy.
module = None
items = func_string.split(".")
for idx in range(1, len(items)):
try:
module = __import__(".".join(items[:idx]))
except ImportError:
break

if module:
for sub in items[1:]:
try:
module = getattr(module, sub)
except AttributeError:
module = None
break
return module
36 changes: 36 additions & 0 deletions tests/tensor/rewriting/test_blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pytensor import function
from pytensor.scalar import log as scalar_log
from pytensor.tensor import matrix, tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv


def test_useless_blockwise_of_elemwise():
x = matrix("x")
out = Blockwise(Elemwise(scalar_log), signature="()->()")(x)

assert isinstance(out.owner.op, Blockwise)
assert isinstance(out.owner.op.core_op, Elemwise)

fn = function([x], out, mode="FAST_COMPILE")
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Elemwise)


def test_useless_unbatched_blockwise():
x = matrix("x")
blockwise_op = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")
out = blockwise_op(x)

assert isinstance(out.owner.op, Blockwise)
assert isinstance(out.owner.op.core_op, MatrixPinv)

fn = function([x], out, mode="FAST_COMPILE")
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv)

# Test that it's not removed when there are batched dims
x = tensor3("x")
out = blockwise_op(x)
fn = function([x], out, mode="FAST_COMPILE")
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
Loading

0 comments on commit 7256c27

Please sign in to comment.